Skip to content

Commit d5b8123

Browse files
mxsl-gr authored and markpollack committed
Add support for Moonshot AI model
- Docs - Tests
1 parent ae76407 commit d5b8123

File tree

38 files changed

+3704
-1
lines changed

38 files changed

+3704
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ Spring AI supports many AI models. For an overview see here. Specific models c
9898
* Transformers (ONNX)
9999
* Anthropic Claude3
100100
* MiniMax
101+
* Moonshot
101102

102103

103104
**Prompts:** Central to AI model interaction is the Prompt, which provides specific instructions for the AI to act upon.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[Moonshot Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/moonshot-chat.html)

models/spring-ai-moonshot/pom.xml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
4+
<modelVersion>4.0.0</modelVersion>
5+
<parent>
6+
<groupId>org.springframework.ai</groupId>
7+
<artifactId>spring-ai</artifactId>
8+
<version>1.0.0-SNAPSHOT</version>
9+
<relativePath>../../pom.xml</relativePath>
10+
</parent>
11+
<artifactId>spring-ai-moonshot</artifactId>
12+
<packaging>jar</packaging>
13+
<name>Spring AI Moonshot</name>
14+
<description>Moonshot support</description>
15+
<url>https://github.com/spring-projects/spring-ai</url>
16+
17+
<scm>
18+
<url>https://github.com/spring-projects/spring-ai</url>
19+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
20+
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
21+
</scm>
22+
23+
<dependencies>
24+
25+
<!-- production dependencies -->
26+
<dependency>
27+
<groupId>org.springframework.ai</groupId>
28+
<artifactId>spring-ai-core</artifactId>
29+
<version>${project.parent.version}</version>
30+
</dependency>
31+
32+
<dependency>
33+
<groupId>org.springframework.ai</groupId>
34+
<artifactId>spring-ai-retry</artifactId>
35+
<version>${project.parent.version}</version>
36+
</dependency>
37+
38+
<!-- Spring Framework -->
39+
<dependency>
40+
<groupId>org.springframework</groupId>
41+
<artifactId>spring-context-support</artifactId>
42+
</dependency>
43+
44+
<dependency>
45+
<groupId>org.springframework.boot</groupId>
46+
<artifactId>spring-boot-starter-logging</artifactId>
47+
</dependency>
48+
49+
<!-- test dependencies -->
50+
<dependency>
51+
<groupId>org.springframework.ai</groupId>
52+
<artifactId>spring-ai-test</artifactId>
53+
<version>${project.version}</version>
54+
<scope>test</scope>
55+
</dependency>
56+
57+
</dependencies>
58+
59+
</project>
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.moonshot;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
21+
import org.springframework.ai.chat.model.ChatModel;
22+
import org.springframework.ai.chat.model.ChatResponse;
23+
import org.springframework.ai.chat.model.Generation;
24+
import org.springframework.ai.chat.model.StreamingChatModel;
25+
import org.springframework.ai.chat.prompt.ChatOptions;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
import org.springframework.ai.model.ModelOptionsUtils;
28+
import org.springframework.ai.moonshot.api.MoonshotApi;
29+
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
30+
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion.Choice;
31+
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionChunk;
32+
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
33+
import org.springframework.ai.retry.RetryUtils;
34+
import org.springframework.http.ResponseEntity;
35+
import org.springframework.retry.support.RetryTemplate;
36+
import org.springframework.util.Assert;
37+
import reactor.core.publisher.Flux;
38+
39+
import java.util.HashMap;
40+
import java.util.List;
41+
import java.util.Map;
42+
import java.util.concurrent.ConcurrentHashMap;
43+
44+
/**
45+
* @author Geng Rong
46+
*/
47+
public class MoonshotChatModel implements ChatModel, StreamingChatModel {
48+
49+
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModel.class);
50+
51+
/**
52+
* The default options used for the chat completion requests.
53+
*/
54+
private final MoonshotChatOptions defaultOptions;
55+
56+
/**
57+
* Low-level access to the Moonshot API.
58+
*/
59+
private final MoonshotApi moonshotApi;
60+
61+
private final RetryTemplate retryTemplate;
62+
63+
/**
64+
* Initializes a new instance of the MoonshotChatModel.
65+
* @param moonshotApi The Moonshot instance to be used for interacting with the
66+
* Moonshot Chat API.
67+
*/
68+
public MoonshotChatModel(MoonshotApi moonshotApi) {
69+
this(moonshotApi, MoonshotChatOptions.builder().withModel(MoonshotApi.DEFAULT_CHAT_MODEL).build());
70+
}
71+
72+
/**
73+
* Initializes a new instance of the MoonshotChatModel.
74+
* @param moonshotApi The Moonshot instance to be used for interacting with the
75+
* Moonshot Chat API.
76+
* @param options The MoonshotChatOptions to configure the chat client.
77+
*/
78+
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options) {
79+
this(moonshotApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
80+
}
81+
82+
/**
83+
* Initializes a new instance of the MoonshotChatModel.
84+
* @param moonshotApi The Moonshot instance to be used for interacting with the
85+
* Moonshot Chat API.
86+
* @param options The MoonshotChatOptions to configure the chat client.
87+
* @param retryTemplate The retry template.
88+
*/
89+
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options, RetryTemplate retryTemplate) {
90+
Assert.notNull(moonshotApi, "MoonshotApi must not be null");
91+
Assert.notNull(options, "Options must not be null");
92+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
93+
this.moonshotApi = moonshotApi;
94+
this.defaultOptions = options;
95+
this.retryTemplate = retryTemplate;
96+
}
97+
98+
@Override
99+
public ChatResponse call(Prompt prompt) {
100+
ChatCompletionRequest request = createRequest(prompt, false);
101+
102+
return this.retryTemplate.execute(ctx -> {
103+
104+
ResponseEntity<ChatCompletion> completionEntity = this.doChatCompletion(request);
105+
106+
var chatCompletion = completionEntity.getBody();
107+
if (chatCompletion == null) {
108+
logger.warn("No chat completion returned for prompt: {}", prompt);
109+
return new ChatResponse(List.of());
110+
}
111+
112+
List<Generation> generations = chatCompletion.choices()
113+
.stream()
114+
.map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice))
115+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
116+
.toList();
117+
118+
return new ChatResponse(generations);
119+
});
120+
}
121+
122+
@Override
123+
public ChatOptions getDefaultOptions() {
124+
return null;
125+
}
126+
127+
@Override
128+
public Flux<ChatResponse> stream(Prompt prompt) {
129+
var request = createRequest(prompt, true);
130+
131+
return retryTemplate.execute(ctx -> {
132+
var completionChunks = this.moonshotApi.chatCompletionStream(request);
133+
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
134+
135+
return completionChunks.map(this::toChatCompletion).map(chatCompletion -> {
136+
String id = chatCompletion.id();
137+
138+
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
139+
if (choice.message().role() != null) {
140+
roleMap.putIfAbsent(id, choice.message().role().name());
141+
}
142+
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
143+
var generation = new Generation(choice.message().content(),
144+
Map.of("id", id, "role", roleMap.get(id), "finishReason", finish));
145+
if (choice.finishReason() != null) {
146+
generation = generation
147+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
148+
}
149+
return generation;
150+
}).toList();
151+
return new ChatResponse(generations);
152+
});
153+
});
154+
155+
}
156+
157+
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
158+
Map<String, Object> map = new HashMap<>();
159+
160+
var message = choice.message();
161+
if (message.role() != null) {
162+
map.put("role", message.role().name());
163+
}
164+
if (choice.finishReason() != null) {
165+
map.put("finishReason", choice.finishReason().name());
166+
}
167+
map.put("id", id);
168+
return map;
169+
}
170+
171+
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
172+
List<Choice> choices = chunk.choices()
173+
.stream()
174+
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason()))
175+
.toList();
176+
177+
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
178+
}
179+
180+
/**
181+
* Accessible for testing.
182+
*/
183+
public MoonshotApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
184+
var chatCompletionMessages = prompt.getInstructions()
185+
.stream()
186+
.map(m -> new MoonshotApi.ChatCompletionMessage(m.getContent(),
187+
MoonshotApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
188+
.toList();
189+
190+
var request = new MoonshotApi.ChatCompletionRequest(chatCompletionMessages, stream);
191+
192+
if (this.defaultOptions != null) {
193+
request = ModelOptionsUtils.merge(request, this.defaultOptions, MoonshotApi.ChatCompletionRequest.class);
194+
}
195+
196+
if (prompt.getOptions() != null) {
197+
if (prompt.getOptions() instanceof ChatOptions) {
198+
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
199+
MoonshotChatOptions.class);
200+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request,
201+
MoonshotApi.ChatCompletionRequest.class);
202+
}
203+
else {
204+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
205+
+ prompt.getOptions().getClass().getSimpleName());
206+
}
207+
}
208+
return request;
209+
}
210+
211+
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
212+
return this.moonshotApi.chatCompletionEntity(request);
213+
}
214+
215+
}

0 commit comments

Comments
 (0)