Skip to content

Commit 997e01c

Browse files
mxsl-grmarkpollack
authored andcommitted
Add support for QianFan AI models
- Add chat, embeddin and image models - Tests - Docs
1 parent db53166 commit 997e01c

File tree

43 files changed

+4352
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+4352
-2
lines changed

models/spring-ai-qianfan/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[QianFan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/qianfan-chat.html)
2+
3+
[QianFan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/qianfan-embeddings.html)
4+
5+
[QianFan Image Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/image/qianfan-image.html)

models/spring-ai-qianfan/pom.xml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
4+
<modelVersion>4.0.0</modelVersion>
5+
<parent>
6+
<groupId>org.springframework.ai</groupId>
7+
<artifactId>spring-ai</artifactId>
8+
<version>1.0.0-SNAPSHOT</version>
9+
<relativePath>../../pom.xml</relativePath>
10+
</parent>
11+
<artifactId>spring-ai-qianfan</artifactId>
12+
<packaging>jar</packaging>
13+
<name>Spring AI QianFan</name>
14+
<description>Baidu QianFan support</description>
15+
<url>https://github.com/spring-projects/spring-ai</url>
16+
17+
<scm>
18+
<url>https://github.com/spring-projects/spring-ai</url>
19+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
20+
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
21+
</scm>
22+
23+
<dependencies>
24+
25+
<!-- production dependencies -->
26+
<dependency>
27+
<groupId>org.springframework.ai</groupId>
28+
<artifactId>spring-ai-core</artifactId>
29+
<version>${project.parent.version}</version>
30+
</dependency>
31+
32+
<dependency>
33+
<groupId>org.springframework.ai</groupId>
34+
<artifactId>spring-ai-retry</artifactId>
35+
<version>${project.parent.version}</version>
36+
</dependency>
37+
38+
<!-- Spring Framework -->
39+
<dependency>
40+
<groupId>org.springframework</groupId>
41+
<artifactId>spring-context-support</artifactId>
42+
</dependency>
43+
44+
<dependency>
45+
<groupId>org.springframework.boot</groupId>
46+
<artifactId>spring-boot-starter-logging</artifactId>
47+
</dependency>
48+
49+
<!-- test dependencies -->
50+
<dependency>
51+
<groupId>org.springframework.ai</groupId>
52+
<artifactId>spring-ai-test</artifactId>
53+
<version>${project.version}</version>
54+
<scope>test</scope>
55+
</dependency>
56+
57+
</dependencies>
58+
59+
</project>
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.qianfan;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.model.ChatModel;
21+
import org.springframework.ai.chat.model.ChatResponse;
22+
import org.springframework.ai.chat.model.Generation;
23+
import org.springframework.ai.chat.model.StreamingChatModel;
24+
import org.springframework.ai.chat.prompt.ChatOptions;
25+
import org.springframework.ai.chat.prompt.Prompt;
26+
import org.springframework.ai.model.ModelOptionsUtils;
27+
import org.springframework.ai.qianfan.api.QianFanApi;
28+
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion;
29+
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk;
30+
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage;
31+
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role;
32+
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest;
33+
import org.springframework.ai.retry.RetryUtils;
34+
import org.springframework.http.ResponseEntity;
35+
import org.springframework.retry.support.RetryTemplate;
36+
import org.springframework.util.Assert;
37+
import reactor.core.publisher.Flux;
38+
39+
import java.util.Collections;
40+
import java.util.List;
41+
import java.util.Map;
42+
43+
/**
44+
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal QianFan}
45+
* backed by {@link QianFanApi}.
46+
*
47+
* @author Geng Rong
48+
* @since 1.0
49+
* @see ChatModel
50+
* @see StreamingChatModel
51+
* @see QianFanApi
52+
*/
53+
public class QianFanChatModel implements ChatModel, StreamingChatModel {
54+
55+
private static final Logger logger = LoggerFactory.getLogger(QianFanChatModel.class);
56+
57+
/**
58+
* The default options used for the chat completion requests.
59+
*/
60+
private final QianFanChatOptions defaultOptions;
61+
62+
/**
63+
* The retry template used to retry the QianFan API calls.
64+
*/
65+
public final RetryTemplate retryTemplate;
66+
67+
/**
68+
* Low-level access to the QianFan API.
69+
*/
70+
private final QianFanApi qianFanApi;
71+
72+
/**
73+
* Creates an instance of the QianFanChatModel.
74+
* @param qianFanApi The QianFanApi instance to be used for interacting with the
75+
* QianFan Chat API.
76+
* @throws IllegalArgumentException if QianFanApi is null
77+
*/
78+
public QianFanChatModel(QianFanApi qianFanApi) {
79+
this(qianFanApi,
80+
QianFanChatOptions.builder().withModel(QianFanApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
81+
}
82+
83+
/**
84+
* Initializes an instance of the QianFanChatModel.
85+
* @param qianFanApi The QianFanApi instance to be used for interacting with the
86+
* QianFan Chat API.
87+
* @param options The QianFanChatOptions to configure the chat client.
88+
*/
89+
public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions options) {
90+
this(qianFanApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
91+
}
92+
93+
/**
94+
* Initializes a new instance of the QianFanChatModel.
95+
* @param qianFanApi The QianFanApi instance to be used for interacting with the
96+
* QianFan Chat API.
97+
* @param options The QianFanChatOptions to configure the chat client.
98+
* @param retryTemplate The retry template.
99+
*/
100+
public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions options, RetryTemplate retryTemplate) {
101+
Assert.notNull(qianFanApi, "QianFanApi must not be null");
102+
Assert.notNull(options, "Options must not be null");
103+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
104+
this.qianFanApi = qianFanApi;
105+
this.defaultOptions = options;
106+
this.retryTemplate = retryTemplate;
107+
}
108+
109+
@Override
110+
public ChatResponse call(Prompt prompt) {
111+
112+
ChatCompletionRequest request = createRequest(prompt, false);
113+
114+
return this.retryTemplate.execute(ctx -> {
115+
116+
ResponseEntity<ChatCompletion> completionEntity = this.doChatCompletion(request);
117+
118+
var chatCompletion = completionEntity.getBody();
119+
if (chatCompletion == null) {
120+
logger.warn("No chat completion returned for prompt: {}", prompt);
121+
return new ChatResponse(List.of());
122+
}
123+
124+
// if (chatCompletion.baseResponse() != null &&
125+
// chatCompletion.baseResponse().statusCode() != 0) {
126+
// throw new RuntimeException(chatCompletion.baseResponse().message());
127+
// }
128+
129+
var generation = new Generation(chatCompletion.result(),
130+
Map.of("id", chatCompletion.id(), "role", Role.ASSISTANT));
131+
return new ChatResponse(Collections.singletonList(generation));
132+
});
133+
}
134+
135+
@Override
136+
public Flux<ChatResponse> stream(Prompt prompt) {
137+
var request = createRequest(prompt, true);
138+
139+
return retryTemplate.execute(ctx -> {
140+
var completionChunks = this.qianFanApi.chatCompletionStream(request);
141+
142+
return completionChunks.map(this::toChatCompletion).map(chatCompletion -> {
143+
String id = chatCompletion.id();
144+
var generation = new Generation(chatCompletion.result(), Map.of("id", id, "role", Role.ASSISTANT));
145+
return new ChatResponse(Collections.singletonList(generation));
146+
});
147+
});
148+
}
149+
150+
/**
151+
* Convert the ChatCompletionChunk into a ChatCompletion.
152+
* @param chunk the ChatCompletionChunk to convert
153+
* @return the ChatCompletion
154+
*/
155+
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
156+
return new ChatCompletion(chunk.id(), chunk.object(), chunk.created(), chunk.result(), chunk.usage());
157+
}
158+
159+
/**
160+
* Accessible for testing.
161+
*/
162+
public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
163+
var chatCompletionMessages = prompt.getInstructions()
164+
.stream()
165+
.map(m -> new ChatCompletionMessage(m.getContent(),
166+
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
167+
.toList();
168+
var systemMessageList = chatCompletionMessages.stream().filter(msg -> msg.role() == Role.SYSTEM).toList();
169+
170+
if (systemMessageList.size() > 1) {
171+
throw new IllegalArgumentException("Only one system message is allowed in the prompt");
172+
}
173+
174+
var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content();
175+
176+
var request = new ChatCompletionRequest(chatCompletionMessages, systemMessage, stream);
177+
178+
if (this.defaultOptions != null) {
179+
request = ModelOptionsUtils.merge(this.defaultOptions, request, ChatCompletionRequest.class);
180+
}
181+
182+
if (prompt.getOptions() != null) {
183+
if (prompt.getOptions() != null) {
184+
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
185+
QianFanChatOptions.class);
186+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
187+
}
188+
else {
189+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
190+
+ prompt.getOptions().getClass().getSimpleName());
191+
}
192+
}
193+
return request;
194+
}
195+
196+
@Override
197+
public ChatOptions getDefaultOptions() {
198+
return QianFanChatOptions.fromOptions(this.defaultOptions);
199+
}
200+
201+
private ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
202+
return this.qianFanApi.chatCompletionEntity(request);
203+
}
204+
205+
}

0 commit comments

Comments
 (0)