30 commits
89f8432
feature: add Wenxin model client
lvchzh May 30, 2024
cabbdbf
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 6, 2024
110dd67
chore: modify the spring-ai-wenxin model name
lvchzh Jun 6, 2024
aa23d42
Resolve Conflict
lvchzh Jun 18, 2024
2f88aac
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 18, 2024
0474110
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 19, 2024
27ab829
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 20, 2024
710ef57
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 20, 2024
6e0ad50
add unit test demo and doc
lvchzh Jun 25, 2024
3a17099
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 27, 2024
63e1de7
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jun 28, 2024
140e498
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 3, 2024
cc70e15
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 8, 2024
c33bcd0
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 9, 2024
c7ee720
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 16, 2024
984c934
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 16, 2024
73a8a2f
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 19, 2024
bb08b7e
resolve conflicts
lvchzh Jul 25, 2024
a5664f6
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 25, 2024
71e8cee
fix bug #1118 & #1117
lvchzh Jul 25, 2024
2021b90
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 26, 2024
f6cf270
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 27, 2024
4bb12ff
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 28, 2024
28ea002
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Jul 31, 2024
ce02832
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Aug 3, 2024
076d339
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Aug 7, 2024
7334957
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Aug 12, 2024
ae339df
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Aug 20, 2024
578b5ef
Merge branch 'main' of github:spring-projects/spring-ai
lvchzh Aug 21, 2024
f8f2d00
fix compile error
lvchzh Aug 21, 2024
106 changes: 106 additions & 0 deletions models/spring-ai-wenxin/pom.xml
@@ -0,0 +1,106 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
		xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.ai</groupId>
		<artifactId>spring-ai</artifactId>
		<version>1.0.0-SNAPSHOT</version>
		<relativePath>../../pom.xml</relativePath>
	</parent>
	<artifactId>spring-ai-wenxin</artifactId>
	<packaging>jar</packaging>
	<name>Spring AI Model - Wenxin</name>
	<description>Wenxin support</description>
	<url>https://github.com/spring-projects/spring-ai</url>

	<scm>
		<url>https://github.com/spring-projects/spring-ai</url>
		<connection>git://github.com/spring-projects/spring-ai.git</connection>
		<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
	</scm>

	<dependencies>

		<!-- codec -->
		<dependency>
			<groupId>commons-codec</groupId>
			<artifactId>commons-codec</artifactId>
		</dependency>

		<!-- production dependencies -->
		<dependency>
			<groupId>org.springframework.ai</groupId>
			<artifactId>spring-ai-core</artifactId>
			<version>${project.parent.version}</version>
		</dependency>

		<dependency>
			<groupId>org.springframework.ai</groupId>
			<artifactId>spring-ai-retry</artifactId>
			<version>${project.parent.version}</version>
		</dependency>

		<!-- NOTE: Required only by the @ConstructorBinding. -->
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot</artifactId>
		</dependency>

		<dependency>
			<groupId>io.rest-assured</groupId>
			<artifactId>json-path</artifactId>
		</dependency>

		<dependency>
			<groupId>com.github.victools</groupId>
			<artifactId>jsonschema-generator</artifactId>
			<version>${victools.version}</version>
		</dependency>

		<dependency>
			<groupId>com.github.victools</groupId>
			<artifactId>jsonschema-module-jackson</artifactId>
			<version>${victools.version}</version>
		</dependency>

		<dependency>
			<groupId>org.springframework</groupId>
			<artifactId>spring-context-support</artifactId>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-logging</artifactId>
		</dependency>

		<!-- test dependencies -->
		<dependency>
			<groupId>org.springframework.ai</groupId>
			<artifactId>spring-ai-test</artifactId>
			<version>${project.version}</version>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.testcontainers</groupId>
			<artifactId>testcontainers</artifactId>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.testcontainers</groupId>
			<artifactId>junit-jupiter</artifactId>
			<scope>test</scope>
		</dependency>

	</dependencies>

	<repositories>
		<repository>
			<id>maven_central</id>
			<name>Maven Central</name>
			<url>https://repo.maven.apache.org/maven2/</url>
		</repository>
	</repositories>

</project>
273 changes: 273 additions & 0 deletions models/spring-ai-wenxin/src/main/java/org/springframework/ai/wenxin/WenxinChatModel.java
@@ -0,0 +1,273 @@
package org.springframework.ai.wenxin;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.wenxin.api.WenxinApi;
import org.springframework.ai.wenxin.metadata.WenxinUsage;
import org.springframework.ai.wenxin.metadata.support.WenxinResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * @author lvchzh
 * @since 1.0.0
 */
public class WenxinChatModel extends
		AbstractFunctionCallSupport<WenxinApi.ChatCompletionMessage, WenxinApi.ChatCompletionRequest, ResponseEntity<WenxinApi.ChatCompletion>>
		implements ChatModel, StreamingChatModel {

	// @formatter:off
	private static final Logger logger = LoggerFactory.getLogger(WenxinChatModel.class);

	private final RetryTemplate retryTemplate;

	private final WenxinApi wenxinApi;

	private WenxinChatOptions defaultOptions;

	public WenxinChatModel(WenxinApi wenxinApi) {
		this(wenxinApi,
				WenxinChatOptions.builder().withModel(WenxinApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
	}

	public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options) {
		this(wenxinApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
	}

	public WenxinChatModel(WenxinApi wenxinApi, WenxinChatOptions options,
			FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
		super(functionCallbackContext);
		Assert.notNull(wenxinApi, "WenxinApi must not be null");
		Assert.notNull(options, "WenxinChatOptions must not be null");
		Assert.notNull(retryTemplate, "RetryTemplate must not be null");
		this.wenxinApi = wenxinApi;
		this.defaultOptions = options;
		this.retryTemplate = retryTemplate;
	}

	@Override
	public ChatResponse call(Prompt prompt) {

		WenxinApi.ChatCompletionRequest request = createRequest(prompt, false);

		return this.retryTemplate.execute(ctx -> {

			ResponseEntity<WenxinApi.ChatCompletion> completionEntity = this.callWithFunctionSupport(request);

			var chatCompletion = completionEntity.getBody();

			if (chatCompletion == null) {
				logger.warn("No chat completion returned for prompt: {}", prompt);
				return new ChatResponse(List.of());
			}

			RateLimit rateLimit = WenxinResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

			Generation generation = new Generation(chatCompletion.result(),
					toMap(chatCompletion.id(), chatCompletion));

			List<Generation> generations = List.of(generation);

			return new ChatResponse(generations, from(chatCompletion, rateLimit, request));
		});
	}

	public static ChatResponseMetadata from(WenxinApi.ChatCompletion result, RateLimit rateLimit,
			WenxinApi.ChatCompletionRequest request) {
		Assert.notNull(result, "Wenxin ChatCompletionResult must not be null");
		return ChatResponseMetadata.builder()
			.withId(result.id())
			.withUsage(WenxinUsage.from(result.usage()))
			.withModel(request.model())
			.withRateLimit(rateLimit)
			.withKeyValue("created", result.created())
			.withKeyValue("sentence_id", result.sentenceId())
			.build();
	}

	@Override
	public ChatOptions getDefaultOptions() {
		return WenxinChatOptions.fromOptions(this.defaultOptions);
	}

	private Map<String, Object> toMap(String id, WenxinApi.ChatCompletion chatCompletion) {
		Map<String, Object> map = new HashMap<>();
		if (chatCompletion.finishReason() != null) {
			map.put("finishReason", chatCompletion.finishReason().name());
		}
		map.put("id", id);
		return map;
	}

	@Override
	public Flux<ChatResponse> stream(Prompt prompt) {
		WenxinApi.ChatCompletionRequest request = createRequest(prompt, true);

		return this.retryTemplate.execute(ctx -> {

			Flux<WenxinApi.ChatCompletionChunk> completionChunks = this.wenxinApi.chatCompletionStream(request);

			// Convert each chunk to a ChatCompletion so the function-call support can
			// handle streamed and blocking responses uniformly.
			return completionChunks.map(this::chunkToChatCompletion).map(chatCompletion -> {
				try {
					chatCompletion = handleFunctionCallOrReturn(request,
							ResponseEntity.of(Optional.of(chatCompletion))).getBody();

					@SuppressWarnings("null")
					String id = chatCompletion.id();
					// Map.of rejects null values, so fall back to an empty string while
					// the chunk carries no finish reason yet.
					String finish = chatCompletion.finishReason() != null ? chatCompletion.finishReason().name() : "";

					var generation = new Generation(chatCompletion.result(),
							Map.of("id", id, "finishReason", finish));
					if (chatCompletion.finishReason() != null) {
						generation = generation.withGenerationMetadata(
								ChatGenerationMetadata.from(chatCompletion.finishReason().name(), null));
					}
					return new ChatResponse(List.of(generation));
				}
				catch (Exception e) {
					logger.error("Error processing chat completion", e);
					return new ChatResponse(List.of());
				}
			});
		});
	}

	private WenxinApi.ChatCompletion chunkToChatCompletion(WenxinApi.ChatCompletionChunk chunk) {
		return new WenxinApi.ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.sentenceId(),
				chunk.isEnd(), chunk.isTruncated(), chunk.finishReason(), chunk.searchInfo(), chunk.result(),
				chunk.needClearHistory(), chunk.flag(), chunk.banRound(), null, chunk.functionCall());
	}

	WenxinApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

		Set<String> functionsForThisRequest = new HashSet<>();

		List<WenxinApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
			.stream()
			.map(m -> new WenxinApi.ChatCompletionMessage(m.getContent(),
					WenxinApi.Role.valueOf(m.getMessageType().name())))
			.toList();

		WenxinApi.ChatCompletionRequest request = new WenxinApi.ChatCompletionRequest(chatCompletionMessages, stream);

		// Runtime options from the prompt take precedence; when absent, the default
		// options below still apply. (The previous else branch dereferenced
		// prompt.getOptions() when it was null and rejected option-less prompts.)
		if (prompt.getOptions() != null) {

			WenxinChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
					ChatOptions.class, WenxinChatOptions.class);

			Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
					IS_RUNTIME_CALL);

			functionsForThisRequest.addAll(promptEnabledFunctions);

			request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, WenxinApi.ChatCompletionRequest.class);
		}

		if (this.defaultOptions != null) {

			Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
					!IS_RUNTIME_CALL);

			functionsForThisRequest.addAll(defaultEnabledFunctions);

			request = ModelOptionsUtils.merge(this.defaultOptions, request, WenxinApi.ChatCompletionRequest.class);
		}

		if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
			request = ModelOptionsUtils.merge(
					WenxinChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(),
					request, WenxinApi.ChatCompletionRequest.class);
		}

		return request;
	}

	private List<WenxinApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
		return this.resolveFunctionCallbacks(functionNames)
			.stream()
			.map(functionCallback -> new WenxinApi.FunctionTool(functionCallback.getName(),
					functionCallback.getDescription(), functionCallback.getInputTypeSchema(), null))
			.toList();
	}

	@Override
	protected WenxinApi.ChatCompletionRequest doCreateToolResponseRequest(
			WenxinApi.ChatCompletionRequest previousRequest, WenxinApi.ChatCompletionMessage responseMessage,
			List<WenxinApi.ChatCompletionMessage> conversationHistory) {

		var functionName = responseMessage.functionCall().name();
		String functionArguments = responseMessage.functionCall().arguments();
		if (!this.functionCallbackRegister.containsKey(functionName)) {
			throw new IllegalStateException("Function callback not found for function name: " + functionName);
		}

		String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

		conversationHistory.add(
				new WenxinApi.ChatCompletionMessage(functionResponse, WenxinApi.Role.FUNCTION, functionName, null));

		WenxinApi.ChatCompletionRequest newRequest = new WenxinApi.ChatCompletionRequest(conversationHistory, false);
		newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, WenxinApi.ChatCompletionRequest.class);
		return newRequest;
	}

	@Override
	protected List<WenxinApi.ChatCompletionMessage> doGetUserMessages(WenxinApi.ChatCompletionRequest request) {
		return request.messages();
	}

	@Override
	protected WenxinApi.ChatCompletionMessage doGetToolResponseMessage(
			ResponseEntity<WenxinApi.ChatCompletion> chatCompletion) {
		return new WenxinApi.ChatCompletionMessage(chatCompletion.getBody().result(), WenxinApi.Role.ASSISTANT, null,
				chatCompletion.getBody().functionCall());
	}

	@Override
	protected ResponseEntity<WenxinApi.ChatCompletion> doChatCompletion(WenxinApi.ChatCompletionRequest request) {
		return this.wenxinApi.chatCompletionEntity(request);
	}

	@Override
	protected Flux<ResponseEntity<WenxinApi.ChatCompletion>> doChatCompletionStream(
			WenxinApi.ChatCompletionRequest request) {
		return this.wenxinApi.chatCompletionStream(request)
			.map(this::chunkToChatCompletion)
			.map(Optional::ofNullable)
			.map(ResponseEntity::of);
	}

	@Override
	protected boolean isToolFunctionCall(ResponseEntity<WenxinApi.ChatCompletion> chatCompletion) {
		var body = chatCompletion.getBody();
		if (body == null) {
			return false;
		}
		return body.finishReason() == WenxinApi.ChatCompletionFinishReason.FUNCTION_CALL;
	}
	// @formatter:on

}
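For reviewers trying the branch locally, here is a minimal usage sketch of the new client. It is not part of this diff, and the WenxinApi constructor arguments are assumed, since that class sits outside the changed files shown here; everything else uses only the API surface visible above.

// Hypothetical setup: the WenxinApi constructor signature is assumed,
// as the class definition is not included in this diff.
var wenxinApi = new WenxinApi(System.getenv("WENXIN_API_KEY"), System.getenv("WENXIN_SECRET_KEY"));

var options = WenxinChatOptions.builder()
		.withModel(WenxinApi.DEFAULT_CHAT_MODEL)
		.withTemperature(0.7f)
		.build();

var chatModel = new WenxinChatModel(wenxinApi, options);

// Blocking call: one complete ChatResponse.
ChatResponse response = chatModel.call(new Prompt("Tell me a joke."));
System.out.println(response.getResult().getOutput().getContent());

// Streaming call: each ChatResponse carries one incremental chunk.
chatModel.stream(new Prompt("Tell me a joke."))
		.doOnNext(r -> System.out.print(r.getResult().getOutput().getContent()))
		.blockLast();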