Skip to content

Commit 249be21

Browse files
committed
feat: add solar chat model
1 parent a75effe commit 249be21

File tree

7 files changed

+805
-1
lines changed

7 files changed

+805
-1
lines changed

models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ public record ChatCompletion(@JsonProperty("id") String id,
324324
@JsonProperty("object") String object,
325325
@JsonProperty("created") Long created,
326326
@JsonProperty("model") String model,
327-
@JsonProperty("system_fingerprint") String systemFingerprint,
327+
@JsonProperty("system_fingerprint") Object systemFingerprint,
328328
@JsonProperty("choices") List<Choice> choices,
329329
@JsonProperty("usage") Usage usage) {
330330
/**
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
package org.springframework.ai.solar.chat;
2+
3+
import java.util.Collections;
4+
import java.util.List;
5+
import java.util.Map;
6+
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
import org.springframework.ai.chat.messages.AssistantMessage;
10+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
11+
import org.springframework.ai.chat.metadata.EmptyUsage;
12+
import org.springframework.ai.chat.model.ChatModel;
13+
import org.springframework.ai.chat.model.ChatResponse;
14+
import org.springframework.ai.chat.model.Generation;
15+
import org.springframework.ai.chat.model.MessageAggregator;
16+
import org.springframework.ai.chat.model.StreamingChatModel;
17+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
18+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
19+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
20+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
21+
import org.springframework.ai.chat.prompt.ChatOptions;
22+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
23+
import org.springframework.ai.chat.prompt.Prompt;
24+
import org.springframework.ai.model.ModelOptionsUtils;
25+
import org.springframework.ai.retry.RetryUtils;
26+
import org.springframework.ai.solar.api.SolarApi;
27+
import org.springframework.ai.solar.api.SolarApi.ChatCompletion;
28+
import org.springframework.ai.solar.api.SolarApi.ChatCompletionChunk;
29+
import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage;
30+
import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage.Role;
31+
import org.springframework.ai.solar.api.SolarApi.ChatCompletionRequest;
32+
import org.springframework.ai.solar.api.SolarConstants;
33+
import org.springframework.ai.solar.metadata.SolarUsage;
34+
import org.springframework.http.ResponseEntity;
35+
import org.springframework.retry.support.RetryTemplate;
36+
import org.springframework.util.Assert;
37+
38+
import io.micrometer.observation.Observation;
39+
import io.micrometer.observation.ObservationRegistry;
40+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
41+
import reactor.core.publisher.Flux;
42+
import reactor.core.publisher.Mono;
43+
44+
public class SolarChatModel implements ChatModel, StreamingChatModel {
45+
46+
private static final Logger logger = LoggerFactory.getLogger(SolarChatModel.class);
47+
48+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
49+
50+
/**
51+
* The retry template used to retry the Solar API calls.
52+
*/
53+
public final RetryTemplate retryTemplate;
54+
55+
/**
56+
* The default options used for the chat completion requests.
57+
*/
58+
private final SolarChatOptions defaultOptions;
59+
60+
/**
61+
* Low-level access to the Solar API.
62+
*/
63+
private final SolarApi solarApi;
64+
65+
/**
66+
* Observation registry used for instrumentation.
67+
*/
68+
private final ObservationRegistry observationRegistry;
69+
70+
/**
71+
* Conventions to use for generating observations.
72+
*/
73+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
74+
75+
/**
76+
* Creates an instance of the SolarChatModel.
77+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar
78+
* Chat API.
79+
* @throws IllegalArgumentException if SolarApi is null
80+
*/
81+
public SolarChatModel(SolarApi SolarApi) {
82+
this(SolarApi, SolarChatOptions.builder().withModel(SolarApi.DEFAULT_CHAT_MODEL).withTemperature(0.7).build());
83+
}
84+
85+
/**
86+
* Initializes an instance of the SolarChatModel.
87+
*
88+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar Chat API.
89+
* @param options The SolarChatOptions to configure the chat client.
90+
*/
91+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options) {
92+
this(SolarApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
93+
}
94+
95+
/**
96+
* Initializes a new instance of the SolarChatModel.
97+
*
98+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar Chat API.
99+
* @param options The SolarChatOptions to configure the chat client.
100+
* @param retryTemplate The retry template.
101+
*/
102+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate) {
103+
this(SolarApi, options, retryTemplate, ObservationRegistry.NOOP);
104+
}
105+
106+
/**
107+
* Initializes a new instance of the SolarChatModel.
108+
*
109+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar Chat API.
110+
* @param options The SolarChatOptions to configure the chat client.
111+
* @param retryTemplate The retry template.
112+
* @param observationRegistry The ObservationRegistry used for instrumentation.
113+
*/
114+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate,
115+
ObservationRegistry observationRegistry) {
116+
Assert.notNull(SolarApi, "SolarApi must not be null");
117+
Assert.notNull(options, "Options must not be null");
118+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
119+
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
120+
this.solarApi = SolarApi;
121+
this.defaultOptions = options;
122+
this.retryTemplate = retryTemplate;
123+
this.observationRegistry = observationRegistry;
124+
}
125+
126+
@Override
127+
public ChatResponse call(Prompt prompt) {
128+
ChatCompletionRequest request = createRequest(prompt, false);
129+
130+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
131+
.prompt(prompt)
132+
.provider(SolarConstants.PROVIDER_NAME)
133+
.requestOptions(buildRequestOptions(request))
134+
.build();
135+
136+
return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
137+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
138+
this.observationRegistry)
139+
.observe(() -> {
140+
ResponseEntity<SolarApi.ChatCompletion> completionEntity = this.retryTemplate
141+
.execute(ctx -> this.solarApi.chatCompletionEntity(request));
142+
143+
var chatCompletion = completionEntity.getBody();
144+
if (chatCompletion == null) {
145+
logger.warn("No chat completion returned for prompt: {}", prompt);
146+
return new ChatResponse(List.of());
147+
}
148+
149+
// @formatter:off
150+
Map<String, Object> metadata = Map.of(
151+
"id", chatCompletion.id(),
152+
"role", SolarApi.ChatCompletionMessage.Role.ASSISTANT
153+
);
154+
// @formatter:on
155+
156+
var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).message().content(),
157+
metadata);
158+
List<Generation> generations = Collections.singletonList(new Generation(assistantMessage));
159+
ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model()));
160+
observationContext.setResponse(chatResponse);
161+
return chatResponse;
162+
});
163+
}
164+
165+
@Override
166+
public Flux<ChatResponse> stream(Prompt prompt) {
167+
return Flux.deferContextual(contextView -> {
168+
ChatCompletionRequest request = createRequest(prompt, true);
169+
170+
var completionChunks = this.solarApi.chatCompletionStream(request);
171+
172+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
173+
.prompt(prompt)
174+
.provider(SolarConstants.PROVIDER_NAME)
175+
.requestOptions(buildRequestOptions(request))
176+
.build();
177+
178+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
179+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
180+
this.observationRegistry);
181+
182+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
183+
184+
Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
185+
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
186+
// @formatter:off
187+
Map<String, Object> metadata = Map.of(
188+
"id", chatCompletion.id(),
189+
"role", Role.ASSISTANT
190+
);
191+
// @formatter:on
192+
193+
var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).delta().content(), metadata);
194+
List<Generation> generations = Collections.singletonList(new Generation(assistantMessage));
195+
return new ChatResponse(generations, from(chatCompletion, request.model()));
196+
}))
197+
.doOnError(observation::error)
198+
.doFinally(s -> observation.stop())
199+
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
200+
return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse);
201+
202+
});
203+
}
204+
205+
private ChatCompletionChunk toChatCompletion(ChatCompletionChunk chunk) {
206+
return new ChatCompletionChunk(chunk.id(), chunk.object(), chunk.created(), chunk.model(),
207+
chunk.systemFingerprint(), chunk.choices());
208+
}
209+
210+
public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
211+
var chatCompletionMessages = prompt.getInstructions()
212+
.stream()
213+
.map(m -> new ChatCompletionMessage(m.getContent(),
214+
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
215+
.toList();
216+
var systemMessageList = chatCompletionMessages.stream()
217+
.filter(msg -> msg.role() == ChatCompletionMessage.Role.SYSTEM)
218+
.toList();
219+
var userMessageList = chatCompletionMessages.stream()
220+
.filter(msg -> msg.role() != ChatCompletionMessage.Role.SYSTEM)
221+
.toList();
222+
223+
if (systemMessageList.size() > 1) {
224+
throw new IllegalArgumentException("Only one system message is allowed in the prompt");
225+
}
226+
227+
var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content();
228+
229+
var request = new SolarApi.ChatCompletionRequest(userMessageList, systemMessage, stream);
230+
231+
if (this.defaultOptions != null) {
232+
request = ModelOptionsUtils.merge(this.defaultOptions, request, SolarApi.ChatCompletionRequest.class);
233+
}
234+
235+
if (prompt.getOptions() != null) {
236+
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
237+
SolarChatOptions.class);
238+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, SolarApi.ChatCompletionRequest.class);
239+
}
240+
return request;
241+
}
242+
243+
@Override
244+
public ChatOptions getDefaultOptions() {
245+
return SolarChatOptions.fromOptions(this.defaultOptions);
246+
}
247+
248+
private ChatOptions buildRequestOptions(ChatCompletionRequest request) {
249+
return ChatOptionsBuilder.builder()
250+
.withModel(request.model())
251+
.withMaxTokens(request.maxTokens())
252+
.withTemperature(request.temperature())
253+
.withTopP(request.topP())
254+
.build();
255+
}
256+
257+
private ChatResponseMetadata from(ChatCompletion result, String model) {
258+
Assert.notNull(result, "Solar ChatCompletionResult must not be null");
259+
return ChatResponseMetadata.builder()
260+
.withId(result.id() != null ? result.id() : "")
261+
.withUsage(result.usage() != null ? SolarUsage.from(result.usage()) : new EmptyUsage())
262+
.withModel(model)
263+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
264+
.build();
265+
}
266+
267+
private ChatResponseMetadata from(ChatCompletionChunk result, String model) {
268+
Assert.notNull(result, "Solar ChatCompletionResult must not be null");
269+
return ChatResponseMetadata.builder()
270+
.withId(result.id() != null ? result.id() : "")
271+
.withModel(model)
272+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
273+
.build();
274+
}
275+
276+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
277+
this.observationConvention = observationConvention;
278+
}
279+
280+
}

0 commit comments

Comments
 (0)