diff --git a/models/spring-ai-wenxin/README.md b/models/spring-ai-wenxin/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/models/spring-ai-wenxin/pom.xml b/models/spring-ai-wenxin/pom.xml new file mode 100644 index 00000000000..d41e5e0bf92 --- /dev/null +++ b/models/spring-ai-wenxin/pom.xml @@ -0,0 +1,106 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-wenxin + jar + Spring AI Model - Wenxin + Wenxin support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + commons-codec + commons-codec + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework.boot + spring-boot + + + + io.rest-assured + json-path + + + + + com.github.victools + jsonschema-generator + ${victools.version} + + + + com.github.victools + jsonschema-module-jackson + ${victools.version} + + + + org.springframework + spring-context-support + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + + + maven_central + Maven Central + https://repo.maven.apache.org/maven2/ + + + + diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java new file mode 100644 index 00000000000..172b19b44f1 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java @@ -0,0 +1,273 @@ +package org.springframework.ai.wenxin; + +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.ai.wenxin.metadata.WenxinUsage; +import org.springframework.ai.wenxin.metadata.support.WenxinResponseHeaderExtractor; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import reactor.core.publisher.Flux; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinChatModel extends + AbstractFunctionCallSupport> + implements ChatModel, StreamingChatModel { + + // 
@formatter:off + private static final Logger logger = LoggerFactory.getLogger(WenxinChatModel.class); + private final RetryTemplate retryTemplate; + private final WenxinApi wenxinApi; + private WenxinChatOptions defaultOptions; + + public WenxinChatModel(WenxinApi wenxinApi) { + this(wenxinApi, + WenxinChatOptions.builder().withModel(WenxinApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()); + } + + public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options) { + this(wenxinApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options, + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) { + super(functionCallbackContext); + Assert.notNull(wenxinApi, "WenxinApi must not be null"); + Assert.notNull(options, "WenxinChatOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.wenxinApi = wenxinApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + WenxinApi.ChatCompletionRequest request = createRequest(prompt, false); + return this.retryTemplate.execute(ctx -> { + + ResponseEntity completionEntity = this.callWithFunctionSupport(request); + + var chatCompletion = completionEntity.getBody(); + + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + RateLimit rateLimit = WenxinResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); + + Generation generation = new Generation(chatCompletion.result(), toMap(chatCompletion.id(), + chatCompletion)); + + List generations = List.of(generation); + + return new ChatResponse(generations, + from(chatCompletion,rateLimit,request)); + }); + } + + public static ChatResponseMetadata from(WenxinApi.ChatCompletion result, RateLimit rateLimit,WenxinApi.ChatCompletionRequest request) { + Assert.notNull(result, "Wenxin ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withUsage(WenxinUsage.from(result.usage())) + .withModel(request.model()) + .withRateLimit(rateLimit) + .withKeyValue("created", result.created()) + .withKeyValue("sentence_id",result.sentenceId()) + .build(); + + } + + @Override + public ChatOptions getDefaultOptions() { + return WenxinChatOptions.fromOptions(this.defaultOptions); + } + + private Map toMap(String id, WenxinApi.ChatCompletion chatCompletion) { + Map map = new HashMap<>(); + if (chatCompletion.finishReason() != null) { + map.put("finishReason", chatCompletion.finishReason().name()); + } + map.put("id", id); + return map; + } + + @Override + public Flux stream(Prompt prompt) { + WenxinApi.ChatCompletionRequest request = createRequest(prompt, true); + + return this.retryTemplate.execute(ctx -> { + + Flux completionChunks = this.wenxinApi.chatCompletionStream(request); + + return completionChunks.map(chunk -> chunkToChatCompletion(chunk)).map(chatCompletion -> { + try { + chatCompletion = handleFunctionCallOrReturn(request, + ResponseEntity.of(Optional.of(chatCompletion))).getBody(); + + @SuppressWarnings("null") + String id = chatCompletion.id(); + String finish = chatCompletion.finishReason() != null ? 
chatCompletion.finishReason().name() : + null; + + var generation = new Generation(chatCompletion.result(), Map.of("id", id, "finishReason", finish)); + if (chatCompletion.finishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(chatCompletion.finishReason().name(), null)); + } + List generations = List.of(generation); + + return new ChatResponse(generations); + } catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + }); + }); + } + + private WenxinApi.ChatCompletion chunkToChatCompletion(WenxinApi.ChatCompletionChunk chunk) { + return new WenxinApi.ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.sentenceId(), + chunk.isEnd(), chunk.isTruncated(), chunk.finishReason(), chunk.searchInfo(), chunk.result(), + chunk.needClearHistory(), chunk.flag(), chunk.banRound(), null, chunk.functionCall()); + } + + WenxinApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + Set functionsForThisRequest = new HashSet<>(); + + List chatCompletionMessages = prompt.getInstructions().stream() + .map(m -> new WenxinApi.ChatCompletionMessage(m.getContent(), + WenxinApi.Role.valueOf(m.getMessageType().name()))).toList(); + WenxinApi.ChatCompletionRequest request = new WenxinApi.ChatCompletionRequest(chatCompletionMessages, stream); + + if (prompt.getOptions() != null) { + + WenxinChatOptions updateRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), + ChatOptions.class, WenxinChatOptions.class); + + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updateRuntimeOptions, + IS_RUNTIME_CALL); + + functionsForThisRequest.addAll(promptEnabledFunctions); + + request = ModelOptionsUtils.merge(updateRuntimeOptions, request, + WenxinApi.ChatCompletionRequest.class); + } else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + + if (this.defaultOptions != null) { + + Set defaultEnableFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + + functionsForThisRequest.addAll(defaultEnableFunctions); + + request = ModelOptionsUtils.merge(this.defaultOptions, request, WenxinApi.ChatCompletionRequest.class); + } + + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + request = ModelOptionsUtils.merge( + WenxinChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), + request, WenxinApi.ChatCompletionRequest.class); + } + + return request; + } + + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functioncallback -> { + var function = new WenxinApi.FunctionTool(functioncallback.getName(), functioncallback.getDescription(), + functioncallback.getInputTypeSchema(), null); + return function; + }).toList(); + } + + @Override + protected WenxinApi.ChatCompletionRequest doCreateToolResponseRequest( + WenxinApi.ChatCompletionRequest previousRequest, WenxinApi.ChatCompletionMessage responseMessage, + List conversationHistory) { + + var functionName = responseMessage.functionCall().name(); + String functionArguments = responseMessage.functionCall().arguments(); + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("Function callback not found for function name: " + functionName); + } + + String functionResponse = 
this.functionCallbackRegister.get(functionName).call(functionArguments); + + conversationHistory.add( + new WenxinApi.ChatCompletionMessage(functionResponse, WenxinApi.Role.FUNCTION, functionName, null)); + + WenxinApi.ChatCompletionRequest newRequest = new WenxinApi.ChatCompletionRequest(conversationHistory, false); + newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, WenxinApi.ChatCompletionRequest.class); + return newRequest; + } + + @Override + protected List doGetUserMessages(WenxinApi.ChatCompletionRequest request) { + return request.messages(); + } + + @Override + protected WenxinApi.ChatCompletionMessage doGetToolResponseMessage( + ResponseEntity chatCompletion) { + return new WenxinApi.ChatCompletionMessage(chatCompletion.getBody().result(), WenxinApi.Role.ASSISTANT, null, + chatCompletion.getBody().functionCall()); + } + + @Override + protected ResponseEntity doChatCompletion(WenxinApi.ChatCompletionRequest request) { + return this.wenxinApi.chatCompletionEntity(request); + } + + @Override + protected Flux> doChatCompletionStream( + WenxinApi.ChatCompletionRequest request) { + return this.wenxinApi.chatCompletionStream(request) + .map(this::chunkToChatCompletion) + .map(Optional::ofNullable) + .map(ResponseEntity::of); + } + + @Override + protected boolean isToolFunctionCall(ResponseEntity chatCompletion) { + var body = chatCompletion.getBody(); + if (body == null) { + return false; + } + return body.finishReason() == WenxinApi.ChatCompletionFinishReason.FUNCTION_CALL; + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java new file mode 100644 index 00000000000..47fb3a8b5f7 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java @@ -0,0 +1,507 @@ +package org.springframework.ai.wenxin; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * @author lvchzh + * @since 1.0.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WenxinChatOptions implements FunctionCallingOptions, ChatOptions { + + // + private @JsonProperty("model") String model; + + private @JsonProperty("penalty_score") Float penaltyScore; + + private @JsonProperty("max_output_tokens") Integer maxOutputTokens; + + private @JsonProperty("response_format") WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat; + + private @JsonProperty("stop") List stop; + + private @JsonProperty("temperature") Float temperature; + + private @JsonProperty("top_p") Float topP; + + private @JsonProperty("functions") List tools; + + private @JsonProperty("tool_choice") String toolChoice; + + private @JsonProperty("user_id") String userId; + + private @JsonProperty("system") String system; + + private @JsonProperty("disable_search") Boolean disableSearch; + + private @JsonProperty("enable_citation") Boolean 
enableCitation; + + private @JsonProperty("enable_trace") Boolean enableTrace; + + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + + public static Builder builder() { + return new Builder(); + } + + // @formatter:off + public static WenxinChatOptions fromOptions(WenxinChatOptions fromOptions) { + return WenxinChatOptions.builder() + .withModel(fromOptions.getModel()) + .withPenaltyScore(fromOptions.getPenaltyScore()) + .withMaxOutputTokens(fromOptions.getMaxOutputTokens()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTools(fromOptions.getTools()) + .withToolChoice(fromOptions.getToolChoice()) + .withUserId(fromOptions.getUserId()) + .withSystem(fromOptions.getSystem()) + .withDisableSearch(fromOptions.getDisableSearch()) + .withEnableCitation(fromOptions.getEnableCitation()) + .withEnableTrace(fromOptions.getEnableTrace()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .build(); + } + // @formatter:on + + public String getModel() { + return model; + } + + @Override + public Float getFrequencyPenalty() { + return 0f; + } + + @Override + public Integer getMaxTokens() { + return 0; + } + + @Override + public Float getPresencePenalty() { + return 0f; + } + + @Override + public List getStopSequences() { + return List.of(); + } + + public void setModel(String model) { + this.model = model; + } + + public Float getPenaltyScore() { + return penaltyScore; + } + + public void setPenaltyScore(Float penaltyScore) { + this.penaltyScore = penaltyScore; + } + + public Integer getMaxOutputTokens() { + return maxOutputTokens; + } + + public void setMaxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + } + + public WenxinApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public List getStop() { + return stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + public ChatOptions copy() { + return WenxinChatOptions.fromOptions(this); + } + + @JsonIgnore + public void setTopK(Integer topK) { + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + + public List getTools() { + return tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getSystem() { + return system; + } + + public void setSystem(String system) { + this.system = system; + } + + public Boolean 
getDisableSearch() { + return disableSearch; + } + + public void setDisableSearch(Boolean disableSearch) { + this.disableSearch = disableSearch; + } + + public Boolean getEnableCitation() { + return enableCitation; + } + + public void setEnableCitation(Boolean enableCitation) { + this.enableCitation = enableCitation; + } + + public Boolean getEnableTrace() { + return enableTrace; + } + + public void setEnableTrace(Boolean enableTrace) { + this.enableTrace = enableTrace; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return functions; + } + + @Override + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((penaltyScore == null) ? 0 : penaltyScore.hashCode()); + result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((userId == null) ? 0 : userId.hashCode()); + result = prime * result + ((system == null) ? 0 : system.hashCode()); + result = prime * result + ((disableSearch == null) ? 0 : disableSearch.hashCode()); + result = prime * result + ((enableCitation == null) ? 0 : enableCitation.hashCode()); + result = prime * result + ((enableTrace == null) ? 
0 : enableTrace.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + WenxinChatOptions other = (WenxinChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.penaltyScore == null) { + if (other.penaltyScore != null) { + return false; + } + } + else if (!this.penaltyScore.equals(other.penaltyScore)) { + return false; + } + if (this.maxOutputTokens == null) { + if (other.maxOutputTokens != null) { + return false; + } + } + else if (!this.maxOutputTokens.equals(other.maxOutputTokens)) { + return false; + } + if (this.responseFormat != other.responseFormat) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + if (this.tools == null) { + if (other.tools != null) { + return false; + } + } + else if (!this.tools.equals(other.tools)) { + return false; + } + if (this.toolChoice == null) { + if (other.toolChoice != null) { + return false; + } + } + else if (!this.toolChoice.equals(other.toolChoice)) { + return false; + } + if (this.userId == null) { + if (other.userId != null) { + return false; + } + } + else if (!this.userId.equals(other.userId)) { + return false; + } + if (this.system == null) { + if (other.system != null) { + return false; + } + } + else if (!this.system.equals(other.system)) { + return false; + } + if (this.disableSearch == null) { + if (other.disableSearch != null) { + return false; + } + } + else if (!this.disableSearch.equals(other.disableSearch)) { + return false; + } + if (this.enableCitation == null) { + if (other.enableCitation != null) { + return false; + } + } + else if (!this.enableCitation.equals(other.enableCitation)) { + return false; + } + if (this.enableTrace == null) { + if (other.enableTrace != null) { + return false; + } + } + else if (!this.enableTrace.equals(other.enableTrace)) { + return false; + } + return true; + } + + public static class Builder { + + protected WenxinChatOptions options; + + public Builder() { + this.options = new WenxinChatOptions(); + } + + public Builder(WenxinChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withPenaltyScore(Float penaltyScore) { + this.options.penaltyScore = penaltyScore; + return this; + } + + public Builder withMaxOutputTokens(Integer maxOutputTokens) { + this.options.maxOutputTokens = maxOutputTokens; + return this; + } + + public Builder withResponseFormat(WenxinApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public 
Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withUserId(String userId) { + this.options.userId = userId; + return this; + } + + public Builder withSystem(String system) { + this.options.system = system; + return this; + } + + public Builder withDisableSearch(Boolean disableSearch) { + this.options.disableSearch = disableSearch; + return this; + } + + public Builder withEnableCitation(Boolean enableCitation) { + this.options.enableCitation = enableCitation; + return this; + } + + public Builder withEnableTrace(Boolean enableTrace) { + this.options.enableTrace = enableTrace; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public WenxinChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java new file mode 100644 index 00000000000..833af199a3a --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingModel.java @@ -0,0 +1,112 @@ +package org.springframework.ai.wenxin; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.ai.wenxin.metadata.WenxinUsage; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinEmbeddingModel extends AbstractEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(WenxinEmbeddingModel.class); + + private final WenxinEmbeddingOptions defaultOptions; + + private final RetryTemplate retryTemplate; + + private final WenxinApi wenxinApi; + + private final MetadataMode metadataMode; + + public WenxinEmbeddingModel(WenxinApi wenxinApi) { + this(wenxinApi, MetadataMode.EMBED); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode) { + this(wenxinApi, metadataMode, + WenxinEmbeddingOptions.builder().withModel(WenxinApi.DEFAULT_EMBEDDING_MODEL).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode, + WenxinEmbeddingOptions wenxinEmbeddingOptions) { + this(wenxinApi, metadataMode, 
wenxinEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public WenxinEmbeddingModel(WenxinApi wenxinApi, MetadataMode metadataMode, WenxinEmbeddingOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(wenxinApi, "WenxinApi must not be null"); + Assert.notNull(metadataMode, "MetadataMode must not be null"); + Assert.notNull(options, "WenxinEmbeddingOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.wenxinApi = wenxinApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public float[] embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + @SuppressWarnings("unchecked") + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + return this.retryTemplate.execute(ctx -> { + + WenxinApi.EmbeddingRequest> apiRequest = (this.defaultOptions != null) + ? new WenxinApi.EmbeddingRequest<>(request.getInstructions(), this.defaultOptions.getModel(), + this.defaultOptions.getUserId()) + : new WenxinApi.EmbeddingRequest<>(request.getInstructions(), WenxinApi.DEFAULT_EMBEDDING_MODEL); + + if (request.getOptions() != null) { + apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, + WenxinApi.EmbeddingRequest.class); + } + + WenxinApi.EmbeddingList apiEmbeddingResponse = this.wenxinApi.embeddings(apiRequest) + .getBody(); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned from request: {}", request); + return new EmbeddingResponse(List.of()); + } + + // var metadata = generateResponseMetadata(apiEmbeddingResponse.id(), + // apiEmbeddingResponse.object(), + // apiEmbeddingResponse.created(), apiEmbeddingResponse.usage()); + + var metadata = new EmbeddingResponseMetadata(apiRequest.model(), + WenxinUsage.from(apiEmbeddingResponse.usage())); + + List embeddings = apiEmbeddingResponse.data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) + .toList(); + + return new EmbeddingResponse(embeddings, metadata); + }); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java new file mode 100644 index 00000000000..88700ece4e8 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinEmbeddingOptions.java @@ -0,0 +1,65 @@ +package org.springframework.ai.wenxin; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinEmbeddingOptions implements EmbeddingOptions { + + private @JsonProperty("model") String model; + + private @JsonProperty("user_id") String userId; + + public static Builder builder() { + return new Builder(); + } + + public String getModel() { + return this.model; + } + + @Override + public Integer getDimensions() { + return 0; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUserId() { + return this.userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public static class Builder { + + protected WenxinEmbeddingOptions options; + + public Builder() { + this.options = new WenxinEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + 
return this; + } + + public Builder withUserId(String userId) { + this.options.setUserId(userId); + return this; + } + + public WenxinEmbeddingOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java new file mode 100644 index 00000000000..1fa1f07251e --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHints.java @@ -0,0 +1,26 @@ +package org.springframework.ai.wenxin.aot; + +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(WenxinApi.class)) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java new file mode 100644 index 00000000000..9b9316f3f80 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/ApiUtils.java @@ -0,0 +1,102 @@ +package org.springframework.ai.wenxin.api; + +import org.apache.commons.codec.binary.Hex; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * @author lvchzh + * @since 1.0.0 + * @description: ApiUtils + */ +public class ApiUtils { + + // @formatter:off + public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com"; + + public static final String DEFAULT_BASE_CHAT_URI = "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"; + + public static final String DEFAULT_BASE_EMBEDDING_URI = "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/"; + + public static final String DEFAULT_HOST = "aip.baidubce.com"; + + private static final String EXPIRATION_PERIOD_IN_SECONDS = "1800"; + + private static final String HMAC_SHA256 = "HmacSHA256"; + + private static final DateTimeFormatter alternateIso8601DateFormat = DateTimeFormatter.ofPattern( + "yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(ZoneOffset.UTC); + + public static Consumer getJsonContentHeaders() { + return (headers) -> headers.setContentType(MediaType.APPLICATION_JSON); + } + + public static String generationSignature(String accessKey, String secretKey, Instant timestamp, String modelName, String uri) { + var canonicalRequest = createCanonicalRequest(uri, modelName); + var authStringPrefix = createAuthStringPrefix(accessKey, timestamp); + var signingKey = hmacSha256Hex(secretKey, 
authStringPrefix); + var signature = hmacSha256Hex(signingKey, canonicalRequest.toString()); + return new StringBuilder() + .append(authStringPrefix) + .append("/host/") + .append(signature) + .toString(); + } + + private static String createAuthStringPrefix(String accessKey, Instant timestamp) { + return new StringBuilder() + .append("bce-auth-v1/").append(accessKey) + .append("/") + .append(formatDate(timestamp)) + .append("/") + .append(EXPIRATION_PERIOD_IN_SECONDS) + .toString(); + } + + private static StringBuilder createCanonicalRequest(String uri, String modelName) { + return new StringBuilder() + .append("POST") + .append("\n") + .append(uri).append(modelName) + .append("\n\n") + .append("host:").append(DEFAULT_HOST); + } + + private static String hmacSha256Hex(String secretKey, String authStringPrefix) { + try { + var mac = Mac.getInstance(HMAC_SHA256); + mac.init(new SecretKeySpec(secretKey.getBytes(StandardCharsets.UTF_8), HMAC_SHA256)); + return new String(Hex.encodeHex(mac.doFinal(authStringPrefix.getBytes(StandardCharsets.UTF_8)))); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new RuntimeException("Failed to generate HMAC-SHA256 signature", e); + } + } + + private static Optional formatAlternateIso8601Date(Instant instant) { + if (instant == null) { + return Optional.empty(); + } + return Optional.of(alternateIso8601DateFormat.format(instant)); + } + + public static String formatDate(Instant instant) { + return formatAlternateIso8601Date(instant).orElseThrow(() -> new RuntimeException("Failed to format date")); + } + + public static String generationAuthorization(String accessKey, String secretKey, Instant timestamp, String model, String uri) { + return generationSignature(accessKey, secretKey, timestamp, model, uri); + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java new file mode 100644 index 00000000000..c313beb4526 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinApi.java @@ -0,0 +1,455 @@ +package org.springframework.ai.wenxin.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.boot.context.properties.bind.ConstructorBinding; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinApi { + + // @formatter:off + public static final String DEFAULT_CHAT_MODEL = ChatModel.ERNIE_3_5_8K.getValue(); + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.Embedding_V1.getValue(); + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final 
WebClient webClient; + + private final String accessKey; + + private final String secretKey; + + private WenxinStreamFunctionCallingHelper chunkMerger = new WenxinStreamFunctionCallingHelper(); + + public WenxinApi(String accessKey, String secretKey) { + this(ApiUtils.DEFAULT_BASE_URL, accessKey, secretKey); + } + + public WenxinApi(String baseUrl, String accessKey, String secretKey) { + this(baseUrl, RestClient.builder(), WebClient.builder(), accessKey, secretKey); + } + + public WenxinApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + String accessKey, String secretKey) { + this(baseUrl, restClientBuilder, webClientBuilder, + RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER, + accessKey, + secretKey); + } + + public WenxinApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler, String accessKey, String secretKey) { + + this.restClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = webClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .build(); + + this.accessKey = accessKey; + this.secretKey = secretKey; + } + + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); + + var timestamp = Instant.now(); + var authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, chatRequest.model(), + ApiUtils.DEFAULT_BASE_CHAT_URI); + + return this.restClient.post() + .uri(ApiUtils.DEFAULT_BASE_CHAT_URI + chatRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + Instant timestamp = Instant.now(); + String authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, chatRequest.model(), + ApiUtils.DEFAULT_BASE_CHAT_URI); + + return this.webClient.post() + .uri(ApiUtils.DEFAULT_BASE_CHAT_URI + chatRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null, null, null, null, null, null, + null, null, null), + (previous, current) -> 
this.chunkMerger.merge(previous, current)); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + public enum ChatModel { + + ERNIE_4_8K("completions_pro"), + ERNIE_4_8K_PREEMPTIVE("completions_pro_preemptive"), + ERNIE_4_8K_PREVIEW("ernie-4.0-8k-preview"), + ERNIE_4_8K_0329("ernie-4.0-8k-0329"), + ERNIE_4_8K_0104("ernie-4.0-8k-0104"), + ERNIE_3_5_8K("completions"), + ERNIE_3_5_8K_0205("ernie-3.5-8k-0205"), + ERNIE_3_5_8K_1222("ernie-3.5-8k-1222"), + ERNIE_3_5_4K_0205("ernie-3.5-4k-0205"), + ERNIE_3_5_8K_PREEMPTIVE("completions_preemptive"), + ERNIE_3_5_8K_Preview("ernie-3.5-8k-preview"), + ERNIE_3_5_8K_0329("ernie-3.5-8k-0329"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + } + + public enum Role { + + @JsonProperty("user") USER, + @JsonProperty("assistant") ASSISTANT, + @JsonProperty("function") FUNCTION + + } + + public enum ChatCompletionFinishReason { + + @JsonProperty("normal") NORMAL, + @JsonProperty("stop") STOP, + @JsonProperty("length") LENGTH, + @JsonProperty("content_filter") CONTENT_FILTER, + @JsonProperty("function_call") FUNCTION_CALL + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FunctionTool( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("parameters") Map parameters, + @JsonProperty("responses") Map responses, + @JsonProperty("examples") List> examples) { + + @ConstructorBinding + public FunctionTool(String name, String description, String jsonSchemaForParameters, + List> examples) { + this(name, description, ModelOptionsUtils.jsonToMap(jsonSchemaForParameters), null, examples); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Example( + @JsonProperty("role") Role role, + @JsonProperty("content") String content, + @JsonProperty("name") String name, + @JsonProperty("function_call") FunctionCall functionCall) { + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FunctionCall( + @JsonProperty("name") String name, + @JsonProperty("arguments") String arguments, + @JsonProperty("thoughts") String thoughts) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionRequest( + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("penalty_score") Float penaltyScore, + @JsonProperty("max_output_tokens") Integer maxOutputTokens, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP, + @JsonProperty("functions") List functions, + @JsonProperty("tool_choice") String toolChoice, + @JsonProperty("user_id") String userId, + @JsonProperty("system") String system, + @JsonProperty("disable_search") Boolean disableSearch, + @JsonProperty("enable_citation") Boolean enableCitation, + @JsonProperty("enable_trace") Boolean enableTrace) { + + public ChatCompletionRequest(List messages, String model, Float temperature) { + this(messages, model, null, null, null, null, false, temperature, null, null, null, null, null, false, + false, false); + } + + public ChatCompletionRequest(List messages, String model, Float temperature, + boolean stream) { + this(messages, model, null, null, null, null, stream, temperature, null, null, null, null, null, false, + false, false); + } + + public ChatCompletionRequest(List messages, String model, List 
tools, + String toolChoice, Boolean disableSearch) { + this(messages, model, null, null, null, null, false, 0.8f, null, tools, toolChoice, null, null, + disableSearch, false, false); + } + + public ChatCompletionRequest(List messages, Boolean stream) { + this(messages, DEFAULT_CHAT_MODEL, null, null, null, null, stream, null, null, null, null, null, null, + false, false, false); + } + + public enum ResponseFormat { + + @JsonProperty("text") TEXT, + @JsonProperty("json_object") JSON_OBJECT + + } + + public static class ToolChoiceBuilder { + + public static final String DEFAULT_TOOL_CHOICE = "auto"; + + public static final String NONE = "none"; + + public static String FUNCTION(String functionName) { + return ModelOptionsUtils.toJsonString( + Map.of( + "type", "function", + "function", + Map.of("name", functionName) + ) + ); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionMessage( + @JsonProperty("content") String content, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("function_call") FunctionCall functionCall) { + + public ChatCompletionMessage(String content, Role role) { + this(content, role, null, null); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletion( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("sentence_id") String sentenceId, + @JsonProperty("is_end") Boolean isEnd, + @JsonProperty("is_truncated") Boolean isTruncated, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("search_info") SearchInfo searchInfo, + @JsonProperty("result") String result, + @JsonProperty("need_clear_history") Boolean needClearHistory, + @JsonProperty("flag") Integer flag, + @JsonProperty("ban_round") Integer banRound, + @JsonProperty("usage") Usage usage, + @JsonProperty("function_call") FunctionCall functionCall) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record SearchInfo(@JsonProperty("search_results") List searchResults) { + + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record SearchResult( + @JsonProperty("index") Integer index, + @JsonProperty("url") String url, + @JsonProperty("title") String title) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("plugins") List plugins) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record PluginUsage( + @JsonProperty("name") String name, + @JsonProperty("parse_tokens") Integer parseTokens, + @JsonProperty("abstract_tokens") Integer abstractTokens, + @JsonProperty("search_tokens") Integer searchTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionChunk( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("sentence_id") String sentenceId, + @JsonProperty("is_end") Boolean isEnd, + @JsonProperty("is_truncated") Boolean isTruncated, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("search_info") ChatCompletion.SearchInfo searchInfo, + @JsonProperty("result") String result, + @JsonProperty("need_clear_history") Boolean needClearHistory, + 
@JsonProperty("flag") Integer flag, + @JsonProperty("ban_round") Integer banRound, + @JsonProperty("usage") Usage usage, + @JsonProperty("function_call") FunctionCall functionCall) { + + } + + // Embedding API + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 2048, "The list must be dimensions or less"); + Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer + || list.get(0) instanceof List, + "The input must be either a String, or a list of Strings or list of integers."); + } + + Instant timestamp = Instant.now(); + String authorization = ApiUtils.generationAuthorization(accessKey, secretKey, timestamp, + embeddingRequest.model(), ApiUtils.DEFAULT_BASE_EMBEDDING_URI); + + return this.restClient.post() + .uri(ApiUtils.DEFAULT_BASE_EMBEDDING_URI + embeddingRequest.model()) + .headers(headers -> { + headers.set("x-bce-date", ApiUtils.formatDate(timestamp)); + headers.set("Authorization", authorization); + }) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + + public enum EmbeddingModel { + Embedding_V1("embedding-v1"), + BGE_LARGE_ZH("bge_large_zh"), + BGE_LARGE_EN("bge_large_en"), + TAO_8K("tao_8k"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Embedding( + @JsonProperty("index") Integer index, + @JsonProperty("embedding")float[] embedding, + @JsonProperty("object") String object) { + + public Embedding(Integer index, float[] embedding) { + this(index, embedding, "embedding"); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingRequest( + @JsonProperty("input") T input, + @JsonProperty("model") String model, + @JsonProperty("user_id") String userId) { + + public EmbeddingRequest(T input, String model) { + this(input, model, null); + } + + public EmbeddingRequest(T input) { + this(input, DEFAULT_EMBEDDING_MODEL); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingList( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("data") List data, + @JsonProperty("usage") Usage usage) { + + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java new file mode 100644 index 00000000000..6fa8dfb8ce8 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/WenxinStreamFunctionCallingHelper.java @@ -0,0 +1,189 @@ +package org.springframework.ai.wenxin.api; + +import org.springframework.ai.wenxin.api.WenxinApi.ChatCompletionChunk; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class 
WenxinStreamFunctionCallingHelper { + + public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { + if (previous == null) { + return current; + } + + String id = (current.id() != null ? current.id() : previous.id()); + String object = (current.object() != null ? current.object() : previous.object()); + Long created = (current.created() != null ? current.created() : previous.created()); + String sentenceId = (current.sentenceId() != null ? current.sentenceId() : previous.sentenceId()); + Boolean isEnd = (current.isEnd() != null ? current.isEnd() : previous.isEnd()); + Boolean isTruncated = (current.isTruncated() != null ? current.isTruncated() : previous.isTruncated()); + Boolean needClearHistory = (current.needClearHistory() != null ? current.needClearHistory() + : previous.needClearHistory()); + String result = (current.result() != null ? current.result() : previous.result()); + Integer flag = (current.flag() != null ? current.flag() : previous.flag()); + Integer banRound = (current.banRound() != null ? current.banRound() : previous.banRound()); + + WenxinApi.ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason() + : previous.finishReason()); + + WenxinApi.ChatCompletion.SearchInfo searchInfo = merge(previous.searchInfo(), current.searchInfo()); + + WenxinApi.FunctionCall functionCall = merge(previous.functionCall(), current.functionCall()); + + WenxinApi.Usage usage = merge(previous.usage(), current.usage()); + + return new ChatCompletionChunk(id, object, created, sentenceId, isEnd, isTruncated, finishReason, searchInfo, + result, needClearHistory, flag, banRound, usage, functionCall); + + } + + private WenxinApi.ChatCompletion.SearchInfo merge(WenxinApi.ChatCompletion.SearchInfo previous, + WenxinApi.ChatCompletion.SearchInfo current) { + if (previous == null) { + return current; + } + + List searchResults = new ArrayList<>(); + WenxinApi.SearchResult lastPreviousSearchResult = null; + if (previous.searchResults() != null) { + lastPreviousSearchResult = previous.searchResults().get(previous.searchResults().size() - 1); + if (previous.searchResults() != null) { + searchResults.addAll(previous.searchResults().subList(0, previous.searchResults().size() - 1)); + } + } + if (current.searchResults() != null) { + if (current.searchResults().size() > 1) { + throw new IllegalArgumentException("Currently only one tool call is supported per message!"); + } + var currentSearchResult = current.searchResults().iterator().next(); + if (currentSearchResult.index() != null) { + if (lastPreviousSearchResult != null) { + searchResults.add(lastPreviousSearchResult); + } + searchResults.add(currentSearchResult); + } + else { + searchResults.add(merge(lastPreviousSearchResult, currentSearchResult)); + } + } + else { + if (lastPreviousSearchResult != null) { + searchResults.add(lastPreviousSearchResult); + } + } + return new WenxinApi.ChatCompletion.SearchInfo(searchResults); + } + + private WenxinApi.SearchResult merge(WenxinApi.SearchResult previous, WenxinApi.SearchResult current) { + if (previous != null) { + return current; + } + + Integer id = current.index() != null ? current.index() : previous.index(); + String title = current.title() != null ? current.title() : previous.title(); + String url = current.url() != null ? 
current.url() : previous.url(); + + return new WenxinApi.SearchResult(id, title, url); + } + + private WenxinApi.FunctionCall merge(WenxinApi.FunctionCall previous, WenxinApi.FunctionCall current) { + if (previous == null) { + return current; + } + + String name = current.name() != null ? current.name() : previous.name(); + String thoughts = current.thoughts() != null ? current.thoughts() : previous.thoughts(); + StringBuilder arguments = new StringBuilder(); + if (previous.arguments() != null) { + arguments.append(previous.arguments()); + } + if (current.arguments() != null) { + arguments.append(current.arguments()); + } + + return new WenxinApi.FunctionCall(name, arguments.toString(), thoughts); + + } + + private WenxinApi.Usage merge(WenxinApi.Usage previous, WenxinApi.Usage current) { + if (previous == null) { + return current; + } + + Integer promptTokens = current.promptTokens() != null ? current.promptTokens() : previous.promptTokens(); + Integer completionTokens = current.completionTokens() != null ? current.completionTokens() + : previous.completionTokens(); + Integer totalTokens = current.totalTokens() != null ? current.totalTokens() : previous.totalTokens(); + + List plugins = new ArrayList<>(); + WenxinApi.Usage.PluginUsage lastPreviousPluginUsage = null; + if (previous.plugins() != null) { + lastPreviousPluginUsage = previous.plugins().get(previous.plugins().size() - 1); + if (previous.plugins().size() > 1) { + plugins.addAll(previous.plugins().subList(0, previous.plugins().size() - 1)); + } + } + if (current.plugins() != null) { + if (current.plugins().size() > 1) { + throw new IllegalArgumentException("Currently only one tool call is supported per message!"); + } + var currentPluginUsage = current.plugins().iterator().next(); + if (currentPluginUsage.name() != null) { + if (lastPreviousPluginUsage != null) { + plugins.add(lastPreviousPluginUsage); + } + plugins.add(currentPluginUsage); + } + else { + plugins.add(merge(lastPreviousPluginUsage, currentPluginUsage)); + } + } + else { + if (lastPreviousPluginUsage != null) { + plugins.add(lastPreviousPluginUsage); + } + } + return new WenxinApi.Usage(promptTokens, completionTokens, totalTokens, plugins); + } + + private WenxinApi.Usage.PluginUsage merge(WenxinApi.Usage.PluginUsage previous, + WenxinApi.Usage.PluginUsage current) { + if (previous == null) { + return current; + } + + String name = current.name() != null ? current.name() : previous.name(); + Integer parseTokens = current.parseTokens() != null ? current.parseTokens() : previous.parseTokens(); + Integer abstractTokens = current.abstractTokens() != null ? current.abstractTokens() + : previous.abstractTokens(); + Integer searchTokens = current.searchTokens() != null ? current.searchTokens() : previous.searchTokens(); + Integer totalTokens = current.totalTokens() != null ? 
current.totalTokens() : previous.totalTokens(); + + return new WenxinApi.Usage.PluginUsage(name, parseTokens, abstractTokens, searchTokens, totalTokens); + } + + public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || chatCompletion.functionCall() == null) { + return false; + } + + return true; + } + + public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || chatCompletion.functionCall() == null) { + return false; + } + + return chatCompletion.finishReason() == WenxinApi.ChatCompletionFinishReason.FUNCTION_CALL; + + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java new file mode 100644 index 00000000000..3432064dda2 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiClientErrorException.java @@ -0,0 +1,17 @@ +package org.springframework.ai.wenxin.api.common; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinApiClientErrorException extends RuntimeException { + + public WenxinApiClientErrorException(String message) { + super(message); + } + + public WenxinApiClientErrorException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java new file mode 100644 index 00000000000..54a859397da --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/api/common/WenxinApiException.java @@ -0,0 +1,17 @@ +package org.springframework.ai.wenxin.api.common; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinApiException extends RuntimeException { + + public WenxinApiException(String message) { + super(message); + } + + public WenxinApiException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java new file mode 100644 index 00000000000..6f0d47167f7 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinRateLimit.java @@ -0,0 +1,68 @@ +package org.springframework.ai.wenxin.metadata; + +import org.springframework.ai.chat.metadata.RateLimit; + +import java.time.Duration; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinRateLimit implements RateLimit { + + // @formatter:off + private static final String RATE_LIMIT_STRING = "{ @type: %1$s, requestsLimit: %2$s, requestsRemaining: %3$s, tokensLimit: %4$s, tokensRemaining: %5$s }"; + + private final Long requestsLimit; + + private final Long requestsRemaining; + + private final Long tokensLimit; + + private final Long tokensRemaining; + + public WenxinRateLimit(Long requestsLimit, Long requestsRemaining, Long tokensLimit, Long tokensRemaining) { + this.requestsLimit = requestsLimit; + this.requestsRemaining = requestsRemaining; + this.tokensLimit = tokensLimit; + this.tokensRemaining = tokensRemaining; + } + + @Override + public Long getRequestsLimit() { + return this.requestsLimit; + } + + @Override + public 
Long getRequestsRemaining() { + return this.requestsRemaining; + } + + @Override + public Duration getRequestsReset() { + throw new UnsupportedOperationException("unimplemented method 'getRequestsReset'"); + } + + @Override + public Long getTokensLimit() { + return this.tokensLimit; + } + + @Override + public Long getTokensRemaining() { + return this.tokensRemaining; + } + + @Override + public Duration getTokensReset() { + throw new UnsupportedOperationException("unimplemented method 'getTokensReset'"); + } + + @Override + public String toString() { + return RATE_LIMIT_STRING.formatted(getClass().getName(), getRequestsLimit(), getRequestsRemaining(), + getTokensLimit(), getTokensRemaining()); + } + // @formatter:on + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java new file mode 100644 index 00000000000..05eca5027ba --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/WenxinUsage.java @@ -0,0 +1,48 @@ +package org.springframework.ai.wenxin.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.util.Assert; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinUsage implements Usage { + + private final WenxinApi.Usage usage; + + protected WenxinUsage(WenxinApi.Usage usage) { + Assert.notNull(usage, "Wenxin Usage must not be null"); + this.usage = usage; + } + + public static WenxinUsage from(WenxinApi.Usage usage) { + return new WenxinUsage(usage); + } + + protected WenxinApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().completionTokens().longValue(); + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java new file mode 100644 index 00000000000..96e7679b5d6 --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinApiResponseHeaders.java @@ -0,0 +1,33 @@ +package org.springframework.ai.wenxin.metadata.support; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public enum WenxinApiResponseHeaders { + + // @formatter:off + REQUESTS_LIMIT_HEADER("X-Ratelimit-Limit-Requests", "Total number of requests allowed within timeframe."), + TOKENS_LIMIT_HEADER("X-Ratelimit-Limit-Tokens", "Remaining number of tokens available in timeframe."), + REQUESTS_REMAINING_HEADER("X-Ratelimit-Remaining-Requests", "Remaining number of requests available in timeframe."), + TOKENS_REMAINING_HEADER("X-Ratelimit-Remaining-Tokens", "Duration of time until the number of tokens reset."); + // @formatter:on + + private String headerName; + + private String description; + + WenxinApiResponseHeaders(String headerName, String description) { + this.headerName = headerName; + this.description = description; + } + + public String getName() { + return this.headerName; + } + + public String getDescription() { + return this.description; + } + +} 
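A minimal sketch of how the four rate-limit headers above are expected to map onto `WenxinRateLimit`; it mirrors the extractor in the next file, but the helper class itself and the header-to-field mapping shown here are illustrative only, not part of this patch:

[source,java]
----
import org.springframework.ai.wenxin.metadata.WenxinRateLimit;
import org.springframework.http.HttpHeaders;

// Illustrative helper: builds a WenxinRateLimit from the X-Ratelimit-* response
// headers, tolerating missing or malformed values the same way
// WenxinResponseHeaderExtractor does (null fields instead of exceptions).
class RateLimitHeaderSketch {

	static WenxinRateLimit fromHeaders(HttpHeaders headers) {
		return new WenxinRateLimit(
				parse(headers.getFirst("X-Ratelimit-Limit-Requests")),
				parse(headers.getFirst("X-Ratelimit-Remaining-Requests")),
				parse(headers.getFirst("X-Ratelimit-Limit-Tokens")),
				parse(headers.getFirst("X-Ratelimit-Remaining-Tokens")));
	}

	private static Long parse(String value) {
		try {
			return value == null ? null : Long.valueOf(value);
		}
		catch (NumberFormatException ex) {
			return null; // invalid header values are ignored, as in the extractor
		}
	}

}
----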
diff --git a/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java new file mode 100644 index 00000000000..c482e047efc --- /dev/null +++ b/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/metadata/support/WenxinResponseHeaderExtractor.java @@ -0,0 +1,59 @@ +package org.springframework.ai.wenxin.metadata.support; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.wenxin.metadata.WenxinRateLimit; +import org.springframework.http.ResponseEntity; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.REQUESTS_LIMIT_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.REQUESTS_REMAINING_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.TOKENS_LIMIT_HEADER; +import static org.springframework.ai.wenxin.metadata.support.WenxinApiResponseHeaders.TOKENS_REMAINING_HEADER; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinResponseHeaderExtractor { + + private static final Logger logger = LoggerFactory.getLogger(WenxinResponseHeaderExtractor.class); + + public static RateLimit extractAiResponseHeaders(ResponseEntity response) { + + Long requestsLimit = getHeaderAsLong(response, REQUESTS_LIMIT_HEADER.getName()); + Long requestRemaining = getHeaderAsLong(response, REQUESTS_REMAINING_HEADER.getName()); + Long tokensLimit = getHeaderAsLong(response, TOKENS_LIMIT_HEADER.getName()); + Long tokensRemaining = getHeaderAsLong(response, TOKENS_REMAINING_HEADER.getName()); + + return new WenxinRateLimit(requestsLimit, requestRemaining, tokensLimit, tokensRemaining); + } + + private static Long getHeaderAsLong(ResponseEntity response, String headerName) { + var headers = response.getHeaders(); + if (headers.containsKey(headerName)) { + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return parseLong(headerName, values.get(0)); + } + } + return null; + } + + private static Long parseLong(String headerName, String headerValue) { + + if (StringUtils.hasText(headerValue)) { + try { + return Long.valueOf(headerValue); + } + catch (NumberFormatException e) { + logger.warn("Value [{}] for HTTP header [{}] is not valid: {}", headerName, headerValue, + e.getMessage()); + } + } + return null; + } + +} diff --git a/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..f6a8c79b1ca --- /dev/null +++ b/models/spring-ai-wenxin/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.wenxin.aot.WenxinRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/WenxinTestConfiguration.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/WenxinTestConfiguration.java new file mode 100644 index 00000000000..fd3f2e6d5bd --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/WenxinTestConfiguration.java @@ -0,0 +1,53 @@ +package 
org.springframework.ai.wenxin; + +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +import static org.springframework.ai.wenxin.api.WenxinApi.EmbeddingModel.TAO_8K; + +@SpringBootConfiguration +public class WenxinTestConfiguration { + + @Bean + public WenxinApi wenxinApi() { + return new WenxinApi(getAccessKey(), getSecretKey()); + } + + private String getSecretKey() { + String secretKey = System.getenv("WENXIN_SECRET_KEY"); + if (!StringUtils.hasText(secretKey)) { + throw new IllegalArgumentException( + "You must provide an Secret Key. Put it in an environment variable under the name " + + "WENXIN_SECRET_KEY"); + } + return secretKey; + } + + private String getAccessKey() { + String accessKey = System.getenv("WENXIN_ACCESS_KEY"); + if (!StringUtils.hasText(accessKey)) { + throw new IllegalArgumentException( + "You must provide an Access Key. Put it in an environment variable under the name " + + "WENXIN_ACCESS_KEY"); + } + return accessKey; + + } + + @Bean + public WenxinChatModel wenxinChatModel(WenxinApi api) { + WenxinChatModel wenxinChatModel = new WenxinChatModel(api); + return wenxinChatModel; + } + + @Bean + WenxinEmbeddingModel wenxinEmbeddingModel(WenxinApi api) { + WenxinEmbeddingModel wenxinEmbeddingModel = new WenxinEmbeddingModel(api, MetadataMode.EMBED, + WenxinEmbeddingOptions.builder().withModel(TAO_8K.value).build()); + return wenxinEmbeddingModel; + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/acme/AcmeIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/acme/AcmeIT.java new file mode 100644 index 00000000000..5a9e92783a9 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/acme/AcmeIT.java @@ -0,0 +1,100 @@ +package org.springframework.ai.wenxin.acme; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.AssistantPromptTemplate; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.reader.JsonReader; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.wenxin.WenxinChatModel; +import org.springframework.ai.wenxin.WenxinEmbeddingModel; +import org.springframework.ai.wenxin.WenxinTestConfiguration; +import org.springframework.ai.wenxin.testutils.AbstractIT; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = WenxinTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "WENXIN_ACCESS_KEY", matches = ".+") 
+@EnabledIfEnvironmentVariable(named = "WENXIN_SECRET_KEY", matches = ".+") +public class AcmeIT extends AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(AcmeIT.class); + + @Value("classpath:/data/acme/bikes.json") + private Resource bikesResource; + + @Value("classpath:/prompts/acme/system-qa.st") + private Resource systemBikePrompt; + + @Autowired + private WenxinEmbeddingModel embeddingModel; + + @Autowired + private WenxinChatModel chatModel; + + @Test + void beanTest() { + assertThat(bikesResource).isNotNull(); + assertThat(embeddingModel).isNotNull(); + assertThat(chatModel).isNotNull(); + } + + @Test + void acmeChain() throws IOException { + JsonReader jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription", "description"); + + var testSplitter = new TokenTextSplitter(); + + logger.info("Creating Embeddings..."); + + VectorStore vectorStore = new SimpleVectorStore(embeddingModel); + vectorStore.accept(testSplitter.apply(jsonReader.get())); + + logger.info("Retrieving relevant documents"); + String userQuery = "What bike is good for city commuting?"; + + List similarDocuments = vectorStore.similaritySearch(userQuery); + + logger.info(String.format("Found %s relevant documents.", similarDocuments.size())); + + Message assistantMessage = getAssistantMessage(similarDocuments); + UserMessage userMessage = new UserMessage(userQuery); + + logger.info("Asking AI generative to reply to question."); + Prompt prompt = new Prompt(List.of(userMessage, assistantMessage, userMessage)); + logger.info("AI responded."); + ChatResponse response = chatModel.call(prompt); + + evaluateQuestionAndAnswer(userQuery, response, true); + } + + private Message getAssistantMessage(List similarDocuments) { + + String documents = similarDocuments.stream() + .map(entry -> entry.getContent()) + .collect(Collectors.joining(System.lineSeparator())); + + AssistantPromptTemplate systemPromptTemplate = new AssistantPromptTemplate(systemBikePrompt); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents)); + return systemMessage; + + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHintsIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHintsIT.java new file mode 100644 index 00000000000..c5feedfe31d --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/aot/WenxinRuntimeHintsIT.java @@ -0,0 +1,28 @@ +package org.springframework.ai.wenxin.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import java.util.Set; + +import static org.assertj.core.api.Java6Assertions.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +public class WenxinRuntimeHintsIT { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + WenxinRuntimeHints wenxinRuntimeHints = new WenxinRuntimeHints(); + wenxinRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(WenxinApi.class); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + } + } + +} \ No newline at end of file 
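For orientation before the low-level integration tests that follow, here is a minimal, hypothetical usage sketch. It only uses constructors and methods that appear elsewhere in this patch (`WenxinApi`, `ChatCompletionRequest`, `WenxinChatModel`); generic type parameters are elided because they are not visible in this diff, and the credentials come from the same environment variables the tests rely on:

[source,java]
----
import java.util.List;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.wenxin.WenxinChatModel;
import org.springframework.ai.wenxin.api.WenxinApi;

// Not part of this patch: a sketch of the two entry points exercised by the tests.
class WenxinUsageSketch {

	public static void main(String[] args) {
		var api = new WenxinApi(System.getenv("WENXIN_ACCESS_KEY"), System.getenv("WENXIN_SECRET_KEY"));

		// Low-level client, same call shape as WenxinApiIT.chatCompletionEntity().
		var message = new WenxinApi.ChatCompletionMessage("Tell me a joke", WenxinApi.Role.USER);
		var entity = api.chatCompletionEntity(
				new WenxinApi.ChatCompletionRequest(List.of(message), "completions", 0.8f, false));
		System.out.println(entity.getBody().result());

		// Portable ChatModel abstraction, wired the same way as WenxinTestConfiguration.
		WenxinChatModel chatModel = new WenxinChatModel(api);
		ChatResponse response = chatModel.call(new Prompt("Tell me a joke"));
		System.out.println(response.getResult().getOutput().getContent());
	}

}
----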
diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java new file mode 100644 index 00000000000..e749d59af60 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/WenxinApiIT.java @@ -0,0 +1,54 @@ +package org.springframework.ai.wenxin.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author lvchzh + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "WENXIN_ACCESS_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "WENXIN_SECRET_KEY", matches = ".+") +public class WenxinApiIT { + + WenxinApi wenxinApi = new WenxinApi(System.getenv("WENXIN_ACCESS_KEY"), System.getenv("WENXIN_SECRET_KEY")); + + @Test + void chatCompletionEntity() { + WenxinApi.ChatCompletionMessage chatCompletionMessage = new WenxinApi.ChatCompletionMessage("Tell me a joke", + WenxinApi.Role.USER); + ResponseEntity response = wenxinApi.chatCompletionEntity( + new WenxinApi.ChatCompletionRequest(List.of(chatCompletionMessage), "completions", 0.8f, false)); + System.out.println(response.getBody().result()); + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + WenxinApi.ChatCompletionMessage chatCompletionMessage = new WenxinApi.ChatCompletionMessage("Tell me a joke", + WenxinApi.Role.USER); + Flux response = wenxinApi.chatCompletionStream( + new WenxinApi.ChatCompletionRequest(List.of(chatCompletionMessage), "completions", 0.8f, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + + @Test + void embeddings() { + ResponseEntity> response = wenxinApi + .embeddings(new WenxinApi.EmbeddingRequest<>(List.of("Hello world"))); + + assertThat(response).isNotNull(); + assertThat(response.getBody().data()).hasSize(1); + assertThat(response.getBody().data().get(0).embedding()).hasSize(384); + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/MockWeatherService.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..01ed927ee15 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/MockWeatherService.java @@ -0,0 +1,55 @@ +package org.springframework.ai.wenxin.api.tool; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +public class MockWeatherService implements Function { + + @Override + public Response apply(Request request) { + return new Response( + Map.of("San Francisco", new CityResp(20, 18, 15, 25, 1013, 50, request.cityInfos.get(0).unit()), + "Tokyo", new CityResp(25, 23, 20, 30, 1013, 50, request.cityInfos.get(0).unit()), "Paris", + new CityResp(15, 13, 10, 20, 1013, 50, request.cityInfos.get(0).unit()))); + } + + public enum Unit { + + C("metric"), F("imperial"); + + public final String unitName; + + Unit(String 
text) { + this.unitName = text; + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, value = "city_infos") List cityInfos) { + + public record CityInfo(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + } + + public record CityResp(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + + } + + public record Response(Map cityRests) { + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/WenxinApiToolFunctionCallIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/WenxinApiToolFunctionCallIT.java new file mode 100644 index 00000000000..1b801ddd4d1 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/api/tool/WenxinApiToolFunctionCallIT.java @@ -0,0 +1,160 @@ +package org.springframework.ai.wenxin.api.tool; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.wenxin.api.WenxinApi; +import org.springframework.http.ResponseEntity; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "WENXIN_ACCESS_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "WENXIN_SECRET_KEY", matches = ".+") +public class WenxinApiToolFunctionCallIT { + + private final Logger logger = LoggerFactory.getLogger(WenxinApiToolFunctionCallIT.class); + + private MockWeatherService weatherService = new MockWeatherService(); + + private WenxinApi completionApi = new WenxinApi(System.getenv("WENXIN_ACCESS_KEY"), + System.getenv("WENXIN_SECRET_KEY")); + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static @NotNull List> getLists() { + var example1 = new WenxinApi.FunctionTool.Example(WenxinApi.Role.USER, + "weather list about San Francisco and Tokyo and Paris?", null, null); + + var example2 = new WenxinApi.FunctionTool.Example(WenxinApi.Role.ASSISTANT, null, null, + new WenxinApi.FunctionCall("getCurrentWeather", """ + { + "city_infos": [ + { + "location": "San Francisco", + "lat": 37.7749, + "lon": -122.4194, + "unit": "C" + }, + { + "location": "Tokyo", + "lat": 35.6895, + "lon": 139.6917, + "unit": "C" + }, + { + "location": "Paris", + "lat": 48.8566, + "lon": 2.3522, + "unit": "C" + } + ] + } + """, null)); + List examples = new ArrayList<>(List.of(example1, example2)); + List> exampleList = new ArrayList<>(List.of(examples)); + return exampleList; + } + + @Test + public void tooFunctionCall() throws 
JsonProcessingException { + + var message = new WenxinApi.ChatCompletionMessage("weather list about San Francisco and Tokyo and Paris?", + WenxinApi.Role.USER); + + List> exampleList = getLists(); + var functionTool = new WenxinApi.FunctionTool("getCurrentWeather", + "Get the weather in location. Return temperature in Celsius.", ModelOptionsUtils.jsonToMap(""" + { + "type": "object", + "properties": { + "city_infos": { + "type": "array", + "description": "city_info", + "items": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + + } + } + } + """), null, exampleList); + + List messages = new ArrayList<>(List.of(message)); + + WenxinApi.ChatCompletionRequest chatCompletionRequest = new WenxinApi.ChatCompletionRequest(messages, + // "completions", List.of(functionTool), + // WenxinApi.ChatCompletionRequest.ToolChoiceBuilder.FUNCTION + // ("getCurrentWeather"), true); + "completions", List.of(functionTool), null, true); + ResponseEntity chatCompletion = completionApi + .chatCompletionEntity(chatCompletionRequest); + + assertThat(chatCompletion.getBody()).isNotNull(); + assertThat(chatCompletion.getBody().functionCall()).isNotNull(); + + // chatCompletion.getBody(). + WenxinApi.ChatCompletionMessage responseMessage = new WenxinApi.ChatCompletionMessage( + chatCompletion.getBody().result(), WenxinApi.Role.ASSISTANT, null, + chatCompletion.getBody().functionCall()); + + if (chatCompletion.getBody().finishReason().name().equals("FUNCTION_CALL")) { + messages.add(responseMessage); + + if ("getCurrentWeather".equals(responseMessage.functionCall().name())) { + MockWeatherService.Request weatherRequest = fromJson(responseMessage.functionCall().arguments(), + MockWeatherService.Request.class); + + MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + + messages.add(new WenxinApi.ChatCompletionMessage(new ObjectMapper().writeValueAsString(weatherResponse), + WenxinApi.Role.FUNCTION, "getCurrentWeather", null)); + } + } + + var functionResponseRequest = new WenxinApi.ChatCompletionRequest(messages, "completions", + List.of(functionTool), null, true); + + ResponseEntity functionResponse = completionApi + .chatCompletionEntity(functionResponseRequest); + + logger.info("Function response: {}", functionResponse.getBody().result()); + + assertThat(functionResponse.getBody().result()).isNotEmpty(); + + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/WenxinEmbeddingIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/WenxinEmbeddingIT.java new file mode 100644 index 00000000000..27a7dc5f227 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/embedding/WenxinEmbeddingIT.java @@ -0,0 +1,36 @@ +package org.springframework.ai.wenxin.embedding; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.wenxin.WenxinEmbeddingModel; +import org.springframework.ai.wenxin.api.WenxinApi; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author lvchzh + * @since 
1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "WENXIN_ACCESS_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "WENXIN_SECRET_KEY", matches = ".+") +public class WenxinEmbeddingIT { + + WenxinEmbeddingModel wenxinEmbeddingModel = new WenxinEmbeddingModel( + new WenxinApi(System.getenv("WENXIN_ACCESS_KEY"), System.getenv("WENXIN_SECRET_KEY"))); + + @Test + void defaultEmbedding() { + assertThat(wenxinEmbeddingModel).isNotNull(); + EmbeddingResponse embeddingResponse = wenxinEmbeddingModel.embedForResponse(List.of("Hello World")); + System.out.println(embeddingResponse); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(384); + assertThat(wenxinEmbeddingModel.dimensions()).isEqualTo(384); + + } + +} diff --git a/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/testutils/AbstractIT.java b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/testutils/AbstractIT.java new file mode 100644 index 00000000000..b6967e62d97 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/java/org/springframework/ai/wenxin/testutils/AbstractIT.java @@ -0,0 +1,79 @@ +package org.springframework.ai.wenxin.testutils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(AbstractIT.class); + + @Autowired + protected ChatModel chatModel; + + @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st") + protected Resource qaEvaluatorAccurateAnswerResource; + + @Value("classpath:/prompts/eval/qa-evaluator-not-related-message.st") + protected Resource qaEvaluatorNotRelatedResource; + + @Value("classpath:/prompts/eval/qa-evaluator-fact-based-answer.st") + protected Resource qaEvaluatorFactBasedAnswerResource; + + @Value("classpath:/prompts/eval/user-evaluator-message.st") + protected Resource userEvaluatorResource; + + protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) + throws IOException { + assertThat(response).isNotNull(); + + String answer = response.getResult().getOutput().getContent(); + logger.info("Question: {}", question); + logger.info("Answer: {}", answer); + + PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + Map.of("question", question, "answer", answer)); + AssistantMessage systemMessage; + if (factBased) { + systemMessage = new AssistantMessage( + qaEvaluatorFactBasedAnswerResource.getContentAsString(StandardCharsets.UTF_8)); + } + else { + systemMessage = new AssistantMessage( + qaEvaluatorAccurateAnswerResource.getContentAsString(StandardCharsets.UTF_8)); + } + Message userMessage = 
userPromptTemplate.createMessage(); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage, userMessage)); + String yesOrNo = chatModel.call(prompt).getResult().getOutput().getContent(); + logger.info("Is Answer related to question: {}", yesOrNo); + if (yesOrNo.equalsIgnoreCase("no")) { + AssistantMessage notRelatedSysMessage = new AssistantMessage( + qaEvaluatorNotRelatedResource.getContentAsString(StandardCharsets.UTF_8)); + prompt = new Prompt(List.of(userMessage, notRelatedSysMessage)); + String reasonForFailure = chatModel.call(prompt).getResult().getOutput().getContent(); + fail(reasonForFailure); + } + else { + logger.info("Answer is related to question."); + assertThat(yesOrNo).isEqualTo("YES"); + } + + } + +} diff --git a/models/spring-ai-wenxin/src/test/resources/data/acme/bikes.json b/models/spring-ai-wenxin/src/test/resources/data/acme/bikes.json new file mode 100644 index 00000000000..981f5a43482 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/data/acme/bikes.json @@ -0,0 +1,11 @@ +[ + { + "name": "E-Adrenaline 8.0 EX1", + "shortDescription": "a versatile and comfortable e-MTB designed for adrenaline enthusiasts who want to explore all types of terrain. It features a powerful motor and advanced suspension to provide a smooth and responsive ride, with a variety of customizable settings to fit any rider's needs.", + "description": "## Overview\r\nIt's right for you if...\r\nYou want to push your limits on challenging trails and terrain, with the added benefit of an electric assist to help you conquer steep climbs and rough terrain. You also want a bike with a comfortable and customizable fit, loaded with high-quality components and technology.\r\n\r\nThe tech you get\r\nA lightweight, full ADV Mountain Carbon frame with a customizable geometry, including an adjustable head tube and chainstay length. A powerful and efficient motor with a 375Wh battery that can assist up to 28 mph when it's on, and provides a smooth and seamless transition when it's off. A SRAM EX1 8-speed drivetrain, a RockShox Lyrik Ultimate fork, and a RockShox Super Deluxe Ultimate rear shock.\r\n\r\nThe final word\r\nOur E-Adrenaline 8.0 EX1 is the perfect bike for adrenaline enthusiasts who want to explore all types of terrain. It's versatile, comfortable, and loaded with advanced technology to provide a smooth and responsive ride, no matter where your adventures take you.\r\n\r\n\r\n## Features\r\nVersatile and customizable\r\nThe E-Adrenaline 8.0 EX1 features a customizable geometry, including an adjustable head tube and chainstay length, so you can fine-tune your ride to fit your needs and preferences. It also features a variety of customizable settings, including suspension tuning, motor assistance levels, and more.\r\n\r\nPowerful and efficient\r\nThe bike is equipped with a powerful and efficient motor that provides a smooth and seamless transition between human power and electric assist. It can assist up to 28 mph when it's on, and provides zero drag when it's off.\r\n\r\nAdvanced suspension\r\nThe E-Adrenaline 8.0 EX1 features a RockShox Lyrik Ultimate fork and a RockShox Super Deluxe Ultimate rear shock, providing advanced suspension technology to absorb shocks and bumps on any terrain. 
The suspension is also customizable to fit your riding style and preferences.\r\n\r\n\r\n## Specs\r\nFrameset\r\nFrame ADV Mountain Carbon main frame & stays, adjustable head tube and chainstay length, tapered head tube, Knock Block, Control Freak internal routing, Boost148, 150mm travel\r\nFork RockShox Lyrik Ultimate, DebonAir spring, Charger 2.1 RC2 damper, remote lockout, tapered steerer, 42mm offset, Boost110, 15mm Maxle Stealth, 160mm travel\r\nShock RockShox Super Deluxe Ultimate, DebonAir spring, Thru Shaft 3-position damper, 230x57.5mm\r\n\r\nWheels\r\nWheel front Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 6-bolt, Boost110, 15mm thru axle\r\nWheel rear Bontrager Line Elite 30, ADV Mountain Carbon, Tubeless Ready, 54T Rapid Drive, 6-bolt, Shimano MicroSpline freehub, Boost148, 12mm thru axle\r\nSkewer rear Bontrager Switch thru axle, removable lever\r\nTire Bontrager XR5 Team Issue, Tubeless Ready, Inner Strength sidewall, aramid bead, 120tpi, 29x2.50''\r\nTire part Bontrager TLR sealant, 6oz\r\n\r\nDrivetrain\r\nShifter SRAM EX1, 8 speed\r\nRear derailleur SRAM EX1, 8 speed\r\nCrank Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nChainring SRAM EX1, 18T, steel\r\nCassette SRAM EX1, 11-48, 8 speed\r\nChain SRAM EX1, 8 speed\r\n\r\nComponents\r\nSaddle Bontrager Arvada, hollow chromoly rails, 138mm width\r\nSeatpost Bontrager Line Elite Dropper, internal routing, 31.6mm\r\nHandlebar Bontrager Line Pro, ADV Carbon, 35mm, 27.5mm rise, 780mm width\r\nGrips Bontrager XR Trail Elite, alloy lock-on\r\nStem Bontrager Line Pro, 35mm, Knock Block, Blendr compatible, 0 degree, 50mm length\r\nHeadset Knock Block Integrated, 62-degree radius, cartridge bearing, 1-1\/8'' top, 1.5'' bottom\r\nBrake SRAM G2 RSC hydraulic disc, carbon levers\r\nBrake rotor SRAM Centerline, centerlock, round edge, 200mm\r\n\r\nAccessories\r\nE-bike system Bosch Performance CX, magnesium motor body, 250 watt, 75 Nm torque\r\nBattery Bosch PowerTube 625, 625Wh\r\nCharger Bosch 4A standard charger\r\nController Bosch Kiox with Anti-theft solution, Bluetooth connectivity, 1.9'' display\r\nTool Bontrager Switch thru axle, removable lever\r\n\r\nWeight\r\nWeight M - 20.25 kg \/ 44.6 lbs (with TLR sealant, no tubes)\r\nWeight limit This bike has a maximum total weight limit (combined weight of bicycle, rider, and cargo) of 300 pounds (136 kg).\r\n\r\n## Sizing & fit\r\n\r\n| Size | Rider Height | Inseam |\r\n|:----:|:------------------------:|:--------------------:|\r\n| S | 155 - 170 cm 5'1\" - 5'7\" | 73 - 80 cm 29\" - 31.5\" |\r\n| M | 163 - 178 cm 5'4\" - 5'10\" | 77 - 83 cm 30.5\" - 32.5\" |\r\n| L | 176 - 191 cm 5'9\" - 6'3\" | 83 - 89 cm 32.5\" - 35\" |\r\n| XL | 188 - 198 cm 6'2\" - 6'6\" | 88 - 93 cm 34.5\" - 36.5\" |\r\n\r\n\r\n## Geometry\r\n\r\nAll measurements provided in cm unless otherwise noted.\r\nSizing table\r\n| Frame size letter | S | M | L | XL |\r\n|---------------------------|-------|-------|-------|-------|\r\n| Actual frame size | 15.8 | 17.8 | 19.8 | 21.8 |\r\n| Wheel size | 29\" | 29\" | 29\" | 29\" |\r\n| A \u2014 Seat tube | 40.0 | 42.5 | 47.5 | 51.0 |\r\n| B \u2014 Seat tube angle | 72.5\u00B0 | 72.8\u00B0 | 73.0\u00B0 | 73.0\u00B0 |\r\n| C \u2014 Head tube length | 9.5 | 10.5 | 11.0 | 11.5 |\r\n| D \u2014 Head angle | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 | 67.8\u00B0 |\r\n| E \u2014 Effective top tube | 59.0 | 62.0 | 65.0 | 68.0 |\r\n| F \u2014 Bottom bracket height | 32.5 | 32.5 | 32.5 | 32.5 |\r\n| G \u2014 Bottom bracket drop | 5.5 | 5.5 | 5.5 | 
5.5 |\r\n| H \u2014 Chainstay length | 45.0 | 45.0 | 45.0 | 45.0 |\r\n| I \u2014 Offset | 4.5 | 4.5 | 4.5 | 4.5 |\r\n| J \u2014 Trail | 11.0 | 11.0 | 11.0 | 11.0 |\r\n| K \u2014 Wheelbase | 113.0 | 117.0 | 120.0 | 123.0 |\r\n| L \u2014 Standover | 77.0 | 77.0 | 77.0 | 77.0 |\r\n| M \u2014 Frame reach | 41.0 | 44.5 | 47.5 | 50.0 |\r\n| N \u2014 Frame stack | 61.0 | 62.0 | 62.5 | 63.0 |", + "price": 1499.99, + "tags": [ + "bicycle" + ] + } +] \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/resources/prompts/acme/system-qa.st b/models/spring-ai-wenxin/src/test/resources/prompts/acme/system-qa.st new file mode 100644 index 00000000000..44db6f210d6 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/prompts/acme/system-qa.st @@ -0,0 +1,7 @@ +You're assisting with questions about products in a bicycle catalog. +Use the information from the DOCUMENTS section to provide accurate answers. +The the answer involves referring to the price or the dimension of the bicycle, include the bicycle name in the response. +If unsure, simply state that you don't know. + +DOCUMENTS: +{documents} \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st new file mode 100644 index 00000000000..56270359545 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-accurate-answer.st @@ -0,0 +1,3 @@ +You are an AI assistant who helps users to evaluate if the answers to questions are accurate. +You will be provided with a QUESTION and an ANSWER. +Your goal is to evaluate the QUESTION and ANSWER and reply with a YES or NO answer. \ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st new file mode 100644 index 00000000000..22fc3e88d14 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-fact-based-answer.st @@ -0,0 +1,7 @@ +You are an AI evaluator. Your task is to verify if the provided ANSWER is a direct and accurate response to the given QUESTION. If the ANSWER is correct and directly answers the QUESTION, reply with "YES". If the ANSWER is not a direct response or is inaccurate, reply with "NO". + +For example: + +If the QUESTION is "What is the capital of France?" and the ANSWER is "Paris.", you should respond with "YES". +If the QUESTION is "What is the capital of France?" and the ANSWER is "France is in Europe.", respond with "NO". +Now, evaluate the following: diff --git a/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st new file mode 100644 index 00000000000..7c33e675e02 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/prompts/eval/qa-evaluator-not-related-message.st @@ -0,0 +1,4 @@ +You are an AI assistant who helps users to evaluate if the answers to questions are accurate. +You will be provided with a QUESTION and an ANSWER. +A previous evaluation has determined that QUESTION and ANSWER are not related. +Give an explanation as to why they are not related. 
\ No newline at end of file diff --git a/models/spring-ai-wenxin/src/test/resources/prompts/eval/user-evaluator-message.st b/models/spring-ai-wenxin/src/test/resources/prompts/eval/user-evaluator-message.st new file mode 100644 index 00000000000..b3fa3e902d2 --- /dev/null +++ b/models/spring-ai-wenxin/src/test/resources/prompts/eval/user-evaluator-message.st @@ -0,0 +1,6 @@ +The question and answer to evaluate are: + +QUESTION: ```{question}``` + +ANSWER: ```{answer}``` + diff --git a/pom.xml b/pom.xml index fad014fe95d..553ef06f18a 100644 --- a/pom.xml +++ b/pom.xml @@ -70,6 +70,7 @@ models/spring-ai-openai models/spring-ai-postgresml models/spring-ai-qianfan + models/spring-ai-wenxin models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-gemini @@ -96,6 +97,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2 spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-zhipuai + spring-ai-spring-boot-starters/spring-ai-starter-wenxin spring-ai-spring-boot-starters/spring-ai-starter-moonshot vector-stores/spring-ai-opensearch-store spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store @@ -187,6 +189,7 @@ 0.5.0 2.10.1 + 5.3.1 diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 7f555636e17..1c4ff3ee879 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -147,6 +147,11 @@ spring-ai-qianfan ${project.version} + + org.springframework.ai + spring-ai-wenxin + ${project.version} + @@ -505,6 +510,12 @@ ${project.version} + + + org.springframework.ai + spring-ai-wenxin-spring-boot-starter + ${project.version} + diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 05a89117c6b..7a24698c021 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.InputStream; +import java.io.Serializable; import java.nio.charset.Charset; import java.util.HashMap; import java.util.Map; @@ -33,7 +34,9 @@ * * @see Message */ -public abstract class AbstractMessage implements Message { +public abstract class AbstractMessage implements Message, Serializable { + + private static final long serialVersionUID = 1L; public static final String MESSAGE_TYPE = "messageType"; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index 1fb46b580e0..adaabd51837 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Objects; @@ -31,7 +32,9 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class AssistantMessage extends AbstractMessage { +public class AssistantMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public record ToolCall(String id, String type, String name, String arguments) { } diff --git 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index ddcff796678..936479b868e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.Map; import java.util.Objects; @@ -27,7 +28,9 @@ * generative to behave like a certain character or to provide answers in a specific * format. */ -public class SystemMessage extends AbstractMessage { +public class SystemMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public SystemMessage(String textContent) { super(MessageType.SYSTEM, textContent, Map.of()); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 42f91f9df54..242fbb17d7f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Objects; @@ -26,7 +27,9 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class ToolResponseMessage extends AbstractMessage { +public class ToolResponseMessage extends AbstractMessage implements Serializable { + + private static final long serialVersionUID = 1L; public record ToolResponse(String id, String name, String responseData) { }; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 53c32425722..0a3a63cf16c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.chat.messages; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -31,7 +32,9 @@ * end-user or developer. They represent questions, prompts, or any input that you want * the generative to respond to. 
*/ -public class UserMessage extends AbstractMessage implements MediaContent { +public class UserMessage extends AbstractMessage implements MediaContent, Serializable { + + private static final long serialVersionUID = 1L; protected final List media; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 809b1cc82a5..7d1c4cb34e6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -17,6 +17,7 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -333,7 +334,7 @@ private static String toGetName(String name) { * @param toUpperCaseTypeValues if true, the type values are converted to upper case. * @return the generated JSON Schema as a String. */ - public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues) { + public static String getJsonSchema(Type type, boolean toUpperCaseTypeValues) { if (SCHEMA_GENERATOR_CACHE.get() == null) { @@ -352,7 +353,7 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); } - ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); + ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(type); if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI // version of it). toUpperCaseTypeValues(node); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 6bd639c883e..f9aaac5afc0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -15,9 +15,11 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.util.Assert; @@ -43,7 +45,7 @@ abstract class AbstractFunctionCallback implements Function, Functio private final String description; - private final Class inputType; + private final Type inputType; private final String inputTypeSchema; @@ -66,7 +68,7 @@ abstract class AbstractFunctionCallback implements Function, Functio * @param objectMapper Used to convert the function's input and output types to and * from JSON. 
*/ - protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class inputType, + protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Type inputType, Function responseConverter, ObjectMapper objectMapper) { Assert.notNull(name, "Name must not be null"); Assert.notNull(description, "Description must not be null"); @@ -107,9 +109,10 @@ public String call(String functionArguments) { return this.andThen(this.responseConverter).apply(request); } - private T fromJson(String json, Class targetClass) { + private T fromJson(String json, Type targetClass) { try { - return this.objectMapper.readValue(json, targetClass); + JavaType javaType = objectMapper.getTypeFactory().constructType(targetClass); + return this.objectMapper.readValue(json, javaType); } catch (JsonProcessingException e) { throw new RuntimeException(e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index bfa9c9c3c28..337ea13a639 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -36,7 +37,7 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback function; - private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, + private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Type inputType, Function responseConverter, ObjectMapper objectMapper, Function function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); @@ -44,8 +45,8 @@ private FunctionCallbackWrapper(String name, String description, String inputTyp } @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + private static Type resolveInputType(Function function) { + return TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); } @Override @@ -69,7 +70,9 @@ public enum SchemaType { private String description; - private Class inputType; + // private Class inputType; + + private Type inputType; private final Function function; @@ -141,6 +144,7 @@ public FunctionCallbackWrapper build() { Assert.notNull(this.objectMapper, "ObjectMapper must not be null"); if (this.inputType == null) { + // this.inputType = resolveInputType(this.function); this.inputType = resolveInputType(this.function); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index 1fa0736d3eb..62db288b748 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -37,7 +37,7 @@ public abstract class TypeResolverHelper { * @param functionClass The function class. * @return The input class of the function. 
*/ - public static Class getFunctionInputClass(Class> functionClass) { + public static Type getFunctionInputClass(Class> functionClass) { return getFunctionArgumentClass(functionClass, 0); } @@ -46,7 +46,7 @@ public static Class getFunctionInputClass(Class> fun * @param functionClass The function class. * @return The output class of the function. */ - public static Class getFunctionOutputClass(Class> functionClass) { + public static Type getFunctionOutputClass(Class> functionClass) { return getFunctionArgumentClass(functionClass, 1); } @@ -56,13 +56,14 @@ public static Class getFunctionOutputClass(Class> fu * @param argumentIndex The index of the argument whose class should be retrieved. * @return The class of the specified function argument. */ - public static Class getFunctionArgumentClass(Class> functionClass, int argumentIndex) { + public static Type getFunctionArgumentClass(Class> functionClass, int argumentIndex) { Type type = TypeResolver.reify(Function.class, functionClass); var argumentType = type instanceof ParameterizedType ? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class; - return toRawClass(argumentType); + // return toRawClass(argumentType); + return argumentType; } /** diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java index 76622a22281..8f81696953b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import java.lang.reflect.Type; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; @@ -36,20 +37,20 @@ public class TypeResolverHelperTests { @Test public void testGetFunctionInputType() { - Class inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class); + Type inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class); assertThat(inputType).isEqualTo(Request.class); } @Test public void testGetFunctionOutputType() { - Class outputType = TypeResolverHelper.getFunctionOutputClass(MockWeatherService.class); + Type outputType = TypeResolverHelper.getFunctionOutputClass(MockWeatherService.class); assertThat(outputType).isEqualTo(Response.class); } @Test public void testGetFunctionInputTypeForInstance() { MockWeatherService service = new MockWeatherService(); - Class inputType = TypeResolverHelper.getFunctionInputClass(service.getClass()); + Type inputType = TypeResolverHelper.getFunctionInputClass(service.getClass()); assertThat(inputType).isEqualTo(Request.class); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/wenxin-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/wenxin-chat.adoc new file mode 100644 index 00000000000..762e180de8e --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/wenxin-chat.adoc @@ -0,0 +1,192 @@ += Wenxin Chat + +Spring AI supports Wenxin Chat the AI language model by Baidu. + +== Pre-requisites == + +You will need to create an API with Baidu to access Wenxin Models. 
+Create an account at the https://login.bce.baidu.com[Baidu signup page] and generate an access key and secret key on the https://console.bce.baidu.com/iam/#/iam/accesslist[AK/SK page].
+The Spring AI project defines the configuration properties `spring.ai.wenxin.access-key` and `spring.ai.wenxin.secret-key` to store the access key and secret key respectively.
+Exporting environment variables named `SPRING_AI_WENXIN_ACCESS_KEY` and `SPRING_AI_WENXIN_SECRET_KEY` with these values also works.
+
+[source,shell]
+----
+export SPRING_AI_WENXIN_ACCESS_KEY=
+export SPRING_AI_WENXIN_SECRET_KEY=
+----
+
+=== Add Repositories and BOM
+
+Spring AI artifacts are published in the Spring Milestone and Snapshot repositories.
+Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system.
+
+To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system.
+
+
+== Auto-configuration
+
+Spring AI provides Spring Boot auto-configuration for the Wenxin chat model.
+To enable it, add the following dependency to your project's Maven `pom.xml` file:
+
+[source,xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-wenxin-spring-boot-starter</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file:
+
+[source,groovy]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-wenxin-spring-boot-starter'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+=== Chat Properties
+
+==== Retry Properties
+
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Wenxin chat model.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException and do not attempt a retry for `4xx` client error codes. | false
+| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty
+| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty
+|====
+
+==== Connection Properties
+
+The prefix `spring.ai.wenxin` is used as the property prefix that lets you connect to Wenxin.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.wenxin.base-url | The URL to connect to. | https://aip.baidubce.com
+| spring.ai.wenxin.access-key | The access key. | -
+| spring.ai.wenxin.secret-key | The secret key. | -
+|====
+
+==== Configuration Properties
+
+The prefix `spring.ai.wenxin.chat` is the property prefix that lets you configure the chat model implementation for Wenxin.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.wenxin.chat.enabled | Enable the Wenxin chat model. | true
+| spring.ai.wenxin.chat.base-url | Optionally overrides `spring.ai.wenxin.base-url` to provide a chat-specific URL. | -
+| spring.ai.wenxin.chat.access-key | Optionally overrides `spring.ai.wenxin.access-key` to provide a chat-specific access key. | -
+| spring.ai.wenxin.chat.secret-key | Optionally overrides `spring.ai.wenxin.secret-key` to provide a chat-specific secret key. | -
+| spring.ai.wenxin.chat.options.model | The Wenxin chat model to use: 'ERNIE-3.5-8K', 'ERNIE-3.5-8K-0205', 'ERNIE-3.5-8K-Preview', 'ERNIE-3.5-8K-0329', 'ERNIE-3.5-128K', 'ERNIE-3.5-8K-0613'. See the https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11[models] page for more information. | `ERNIE-3.5-8K`
+| spring.ai.wenxin.chat.options.penalty_score | Penalizes tokens that have already been generated in order to reduce repetition. A higher value means a stronger penalty. The API default is 1.0 and the value can be set within [1.0, 2.0]. | 0.8
+| spring.ai.wenxin.chat.options.max_output_tokens | The maximum number of tokens to generate in the response. | -
+| spring.ai.wenxin.chat.options.response_format | Specifies the format of the response content. Optional values: `json_object` (return JSON, which may not always give the desired result) and `text` (return plain text). If the response_format parameter is not specified, the default is `text`. | -
+| spring.ai.wenxin.chat.options.stop | Stop sequences. When the generated result ends with any element of `stop`, text generation stops. Each element is at most 20 characters long and up to 4 elements are allowed. | -
+| spring.ai.wenxin.chat.options.temperature | Controls randomness: higher values make the output more random, while lower values make it more focused and deterministic. The API default is 0.8, the range is (0, 1.0], and it cannot be 0. | 1
+| spring.ai.wenxin.chat.options.top_p | Nucleus sampling: the larger the value, the more diverse the output text. The API default is 0.8, with a range of [0, 1.0]. | -
+| spring.ai.wenxin.chat.options.functions | A list of descriptions of functions the model may call. The number of functions is unlimited, but the combined content of the message `content`, `functions` and `system` fields cannot exceed 20000 characters or 5120 tokens. | -
+| spring.ai.wenxin.chat.options.tool_choice | In a function-calling context, prompts the model to call the specified function (not mandatory). The specified function name must exist in `functions`. | -
+| spring.ai.wenxin.chat.options.user_id | A unique identifier for the end user. | -
+| spring.ai.wenxin.chat.options.system | The model persona, mainly used for persona settings, for example "You are an AI assistant produced by company XXX". The combined content of the message `content`, `functions` and `system` fields cannot exceed 20000 characters or 5120 tokens. If `system` and `functions` are used at the same time, the results are not guaranteed; this is still being optimized. | -
+| spring.ai.wenxin.chat.options.disable_search | Whether to force-disable the real-time search feature. The default is false, i.e. search is not disabled. | -
+| spring.ai.wenxin.chat.options.enable_citation | Whether to enable superscript citation markers in the response. If enabled, search source information (`search_info`) may be returned when a search is triggered; see the response parameter description for its content. The default is false (disabled). | -
+| spring.ai.wenxin.chat.options.enable_trace | Whether to return search trace information. If enabled, search trace information (`search_info`) is returned when search enhancement is triggered; see the response parameter description for its content. The default is false (not returned). | -
+|====
+
+TIP: All properties prefixed with `spring.ai.wenxin.chat.options` can be overridden at runtime by adding request-specific <<chat-options>> to the `Prompt` call.
+
+== Runtime Options [[chat-options]]
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java[WenxinChatOptions.java] provides model configurations, such as the model to use, the temperature, the penalty score, etc.
+
+On start-up, the default options can be configured with the `WenxinChatModel(api, options)` constructor or the `spring.ai.wenxin.chat.options.*` properties.
+
+At run-time you can override the default options by adding new, request-specific options to the `Prompt` call.
+For example, to override the default model and temperature for a specific request:
+
+[source,java]
+----
+ChatResponse response = chatModel.call(
+    new Prompt(
+        "Generate the names of 5 famous pirates.",
+        WenxinChatOptions.builder()
+            .withModel("completions")
+            .withTemperature(0.4f)
+            .build()
+    ));
+----
+
+TIP: In addition to the model-specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatOptions.java[WenxinChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
+
+== Function Calling
+
+You can register custom Java functions with the WenxinChatModel and have the Wenxin model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions.
+This is a powerful technique to connect the LLM capabilities with external tools and APIs.
+
+
+== Sample Controller
+
+https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-wenxin-spring-boot-starter` to your pom (or gradle) dependencies.
+
+Add an `application.properties` file under the `src/main/resources` directory to enable and configure the Wenxin chat model:
+
+[source,application.properties]
+----
+spring.ai.wenxin.access-key=YOUR_ACCESS_KEY
+spring.ai.wenxin.secret-key=YOUR_SECRET_KEY
+spring.ai.wenxin.chat.options.model=ERNIE-3.5-8K
+spring.ai.wenxin.chat.options.temperature=0.7
+----
+
+TIP: Replace the `access-key` and `secret-key` values with your Wenxin credentials.
+
+This will create a `WenxinChatModel` implementation that you can inject into your class.
+Here is an example of a simple `@Controller` class that uses the chat model for text generation.
+
+[source,java]
+----
+@RestController
+@RequestMapping("/wenxin")
+public class WenxinSimpleAiController {
+
+    private final ChatModel chatModel;
+
+    private final StreamingChatModel streamingChatModel;
+
+    public WenxinSimpleAiController(@Qualifier("wenxinChatModel") ChatModel chatModel,
+            @Qualifier("wenxinChatModel") StreamingChatModel streamingChatModel) {
+        this.chatModel = chatModel;
+        this.streamingChatModel = streamingChatModel;
+    }
+
+    @GetMapping("/simple")
+    public Map<String, String> completion(
+            @RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        return Map.of("generation", chatModel.call(message));
+    }
+
+    @GetMapping("/stream")
+    public Flux<ServerSentEvent<String>> stream(
+            @RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        return streamingChatModel.stream(message).map(data -> ServerSentEvent.builder(data).build());
+    }
+}
+----
+
+
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml
index 250d6246ca3..cb34ce34575 100644
--- a/spring-ai-spring-boot-autoconfigure/pom.xml
+++ b/spring-ai-spring-boot-autoconfigure/pom.xml
@@ -42,6 +42,13 @@ true + + org.springframework.ai + spring-ai-wenxin + ${project.parent.version} + true + + org.springframework.ai spring-ai-openai
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java
new file mode 100644
index 00000000000..867c9cf563d
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinAutoConfiguration.java
@@ -0,0 +1,104 @@
+package org.springframework.ai.autoconfigure.wenxin;
+
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.model.function.FunctionCallback;
+import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.ai.wenxin.WenxinChatModel;
+import org.springframework.ai.wenxin.WenxinEmbeddingModel;
+import org.springframework.ai.wenxin.api.WenxinApi;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
+import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
+
+import java.util.List;
+
+/**
+ * @author lvchzh
+ * @since 1.0.0
+ */
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
+		SpringAiRetryAutoConfiguration.class })
+@ConditionalOnClass(WenxinApi.class)
+@EnableConfigurationProperties({ WenxinConnectionProperties.class, WenxinChatProperties.class,
+		WenxinEmbeddingProperties.class })
+@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class,
+		WebClientAutoConfiguration.class })
+public class WenxinAutoConfiguration {
+
+	@Bean
+	@ConditionalOnMissingBean
+	@ConditionalOnProperty(prefix = WenxinChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
+			matchIfMissing = true)
+	public WenxinChatModel wenxinChatModel(WenxinConnectionProperties commonProperties,
+			WenxinChatProperties chatProperties, RestClient.Builder restClientBuilder,
+			WebClient.Builder webClientBuilder, List<FunctionCallback> toolFunctionCallbacks,
+			FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate,
+			ResponseErrorHandler responseErrorHandler) {
+		var wenxinApi = wenxinApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
+				chatProperties.getAccessKey(), commonProperties.getAccessKey(), chatProperties.getSecretKey(),
+				commonProperties.getSecretKey(), restClientBuilder, webClientBuilder, responseErrorHandler);
+
+		if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
+			chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
+		}
+
+		return new WenxinChatModel(wenxinApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
+	}
+
+	@Bean
+	@ConditionalOnMissingBean
+	@ConditionalOnProperty(prefix = WenxinEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
+			matchIfMissing = true)
+	public WenxinEmbeddingModel wenxinEmbeddingModel(WenxinConnectionProperties commonProperties,
+			WenxinEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
+			WebClient.Builder webClientBuilder, RetryTemplate retryTemplate,
+			ResponseErrorHandler responseErrorHandler) {
+		var wenxinApi = wenxinApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
+				embeddingProperties.getAccessKey(), commonProperties.getAccessKey(), embeddingProperties.getSecretKey(),
+				commonProperties.getSecretKey(), restClientBuilder, webClientBuilder, responseErrorHandler);
+
+		return new WenxinEmbeddingModel(wenxinApi, embeddingProperties.getMetadataMode(),
+				embeddingProperties.getOptions(), retryTemplate);
+	}
+
+	private WenxinApi wenxinApi(String chatBaseUrl, String commonBaseUrl, String accessKey, String commonAccessKey,
+			String secretKey, String commonSecretKey, RestClient.Builder restClientBuilder,
+			WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
+
+		// Prefer the model-specific settings and fall back to the common spring.ai.wenxin.* properties.
+		String resolvedChatBaseUrl = StringUtils.hasText(chatBaseUrl) ? chatBaseUrl : commonBaseUrl;
+		Assert.hasText(resolvedChatBaseUrl, "The Wenxin API base URL must be set!");
+
+		String resolvedAccessKey = StringUtils.hasText(accessKey) ? accessKey : commonAccessKey;
+		Assert.hasText(resolvedAccessKey, "The Wenxin API access key must be set!");
+
+		String resolvedSecretKey = StringUtils.hasText(secretKey) ?
secretKey : commonSecretKey;
+		Assert.hasText(resolvedSecretKey, "The Wenxin API secret key must be set!");
+
+		return new WenxinApi(resolvedChatBaseUrl, restClientBuilder, webClientBuilder, responseErrorHandler,
+				resolvedAccessKey, resolvedSecretKey);
+	}
+
+	@Bean
+	@ConditionalOnMissingBean
+	public FunctionCallbackContext springAiFunctionCallbackContext(ApplicationContext context) {
+		FunctionCallbackContext manager = new FunctionCallbackContext();
+		manager.setApplicationContext(context);
+		return manager;
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java
new file mode 100644
index 00000000000..43bde342831
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinChatProperties.java
@@ -0,0 +1,43 @@
+package org.springframework.ai.autoconfigure.wenxin;
+
+import org.springframework.ai.wenxin.WenxinChatOptions;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * @author lvchzh
+ * @since 1.0.0
+ */
+@ConfigurationProperties(WenxinChatProperties.CONFIG_PREFIX)
+public class WenxinChatProperties extends WenxinParentProperties {
+
+	public static final String CONFIG_PREFIX = "spring.ai.wenxin.chat";
+
+	public static final String DEFAULT_CHAT_MODEL = "completions";
+
+	private static final Double DEFAULT_TEMPERATURE = 0.7;
+
+	private boolean enabled = true;
+
+	//@formatter:off
+	private WenxinChatOptions options = WenxinChatOptions.builder()
+			.withModel(DEFAULT_CHAT_MODEL)
+			.withTemperature(DEFAULT_TEMPERATURE.floatValue())
+			.build();
+	//@formatter:on
+	public WenxinChatOptions getOptions() {
+		return options;
+	}
+
+	public void setOptions(WenxinChatOptions options) {
+		this.options = options;
+	}
+
+	public boolean isEnabled() {
+		return enabled;
+	}
+
+	public void setEnabled(boolean enabled) {
+		this.enabled = enabled;
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java
new file mode 100644
index 00000000000..f64cc2059af
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinConnectionProperties.java
@@ -0,0 +1,20 @@
+package org.springframework.ai.autoconfigure.wenxin;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * @author lvchzh
+ * @since 1.0.0
+ */
+@ConfigurationProperties(WenxinConnectionProperties.CONFIG_PREFIX)
+public class WenxinConnectionProperties extends WenxinParentProperties {
+
+	public static final String CONFIG_PREFIX = "spring.ai.wenxin";
+
+	public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
+
+	public WenxinConnectionProperties() {
+		super.setBaseUrl(DEFAULT_BASE_URL);
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java
new file mode 100644
index 00000000000..9213a92931a
--- /dev/null
+++
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinEmbeddingProperties.java @@ -0,0 +1,50 @@ +package org.springframework.ai.autoconfigure.wenxin; + +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.wenxin.WenxinEmbeddingOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author lvchzh + * @since 1.0.0 + */ +@ConfigurationProperties(WenxinEmbeddingProperties.CONFIG_PREFIX) +public class WenxinEmbeddingProperties extends WenxinParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.wenxin.embedding"; + + public static final String DEFAULT_EMBEDDING_MODEL = "Embedding-V1"; + + private boolean enabled = true; + + private MetadataMode metadataMode = MetadataMode.EMBED; + + private WenxinEmbeddingOptions options = WenxinEmbeddingOptions.builder() + .withModel(DEFAULT_EMBEDDING_MODEL) + .build(); + + public WenxinEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(WenxinEmbeddingOptions options) { + this.options = options; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java new file mode 100644 index 00000000000..99f2d374821 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/wenxin/WenxinParentProperties.java @@ -0,0 +1,39 @@ +package org.springframework.ai.autoconfigure.wenxin; + +/** + * @author lvchzh + * @since 1.0.0 + */ +public class WenxinParentProperties { + + private String baseUrl; + + private String accessKey; + + private String secretKey; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getAccessKey() { + return accessKey; + } + + public void setAccessKey(String accessKey) { + this.accessKey = accessKey; + } + + public String getSecretKey() { + return secretKey; + } + + public void setSecretKey(String secretKey) { + this.secretKey = secretKey; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 9fe46cc2609..a7e155f2c3d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,4 +1,5 @@ org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration +org.springframework.ai.autoconfigure.wenxin.WenxinAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration 
org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingModelAutoConfiguration diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml new file mode 100644 index 00000000000..666d1545f91 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-wenxin/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-wenxin-spring-boot-starter + jar + Spring AI Starter - Wenxin + Spring AI Wenxin Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-wenxin + ${project.parent.version} + + + +