Skip to content

Commit 2f69cf0

Browse files
committed
Add support for Solar AI models
1 parent 5e1f681 commit 2f69cf0

File tree

22 files changed

+2182
-2
lines changed

22 files changed

+2182
-2
lines changed

models/spring-ai-solar/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[Solar Chat Documentation](https://console.upstage.ai/docs/capabilities/chat)
2+
3+
[Solar Embedding Documentation](https://console.upstage.ai/docs/capabilities/embeddings)

models/spring-ai-solar/pom.xml

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
~ Copyright 2023-2024 the original author or authors.
4+
~
5+
~ Licensed under the Apache License, Version 2.0 (the "License");
6+
~ you may not use this file except in compliance with the License.
7+
~ You may obtain a copy of the License at
8+
~
9+
~ https://www.apache.org/licenses/LICENSE-2.0
10+
~
11+
~ Unless required by applicable law or agreed to in writing, software
12+
~ distributed under the License is distributed on an "AS IS" BASIS,
13+
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
~ See the License for the specific language governing permissions and
15+
~ limitations under the License.
16+
-->
17+
18+
<project xmlns="http://maven.apache.org/POM/4.0.0"
19+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
20+
<modelVersion>4.0.0</modelVersion>
21+
<parent>
22+
<groupId>org.springframework.ai</groupId>
23+
<artifactId>spring-ai</artifactId>
24+
<version>1.0.0-SNAPSHOT</version>
25+
<relativePath>../../pom.xml</relativePath>
26+
</parent>
27+
<artifactId>spring-ai-solar</artifactId>
28+
<packaging>jar</packaging>
29+
<name>Spring AI Solar</name>
30+
<description>Upstage Solar support</description>
31+
<url>https://github.com/spring-projects/spring-ai</url>
32+
33+
<scm>
34+
<url>https://github.com/spring-projects/spring-ai</url>
35+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
36+
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
37+
</scm>
38+
39+
<properties>
40+
</properties>
41+
42+
<dependencies>

    <!-- production dependencies -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-core</artifactId>
        <version>${project.parent.version}</version>
    </dependency>

    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-retry</artifactId>
        <version>${project.parent.version}</version>
    </dependency>

    <!-- Spring Framework -->
    <dependency>
        <groupId>org.springframework</groupId>
        <artifactId>spring-context-support</artifactId>
    </dependency>

    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-logging</artifactId>
    </dependency>

    <!-- test dependencies -->
    <!-- The module's own classes are already on its test classpath, so it must not list itself here. -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-test</artifactId>
        <version>${project.version}</version>
        <scope>test</scope>
    </dependency>

    <dependency>
        <groupId>io.micrometer</groupId>
        <artifactId>micrometer-observation-test</artifactId>
        <scope>test</scope>
    </dependency>

</dependencies>
89+
90+
</project>
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package org.springframework.ai.solar;
2+
3+
import java.util.Collections;
4+
import java.util.List;
5+
import java.util.Map;
6+
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
import org.springframework.ai.chat.messages.AssistantMessage;
10+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
11+
import org.springframework.ai.chat.metadata.EmptyUsage;
12+
import org.springframework.ai.chat.model.ChatModel;
13+
import org.springframework.ai.chat.model.ChatResponse;
14+
import org.springframework.ai.chat.model.Generation;
15+
import org.springframework.ai.chat.model.MessageAggregator;
16+
import org.springframework.ai.chat.model.StreamingChatModel;
17+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
18+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
19+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
20+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
21+
import org.springframework.ai.chat.prompt.ChatOptions;
22+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
23+
import org.springframework.ai.chat.prompt.Prompt;
24+
import org.springframework.ai.model.ModelOptionsUtils;
25+
import org.springframework.ai.retry.RetryUtils;
26+
import org.springframework.ai.solar.api.SolarApi;
27+
import org.springframework.ai.solar.api.common.SolarConstants;
28+
import org.springframework.ai.solar.metadata.SolarUsage;
29+
import org.springframework.http.ResponseEntity;
30+
import org.springframework.retry.support.RetryTemplate;
31+
import org.springframework.util.Assert;
32+
33+
import io.micrometer.observation.Observation;
34+
import io.micrometer.observation.ObservationRegistry;
35+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
36+
import reactor.core.publisher.Flux;
37+
import reactor.core.publisher.Mono;
38+
39+
public class SolarChatModel implements ChatModel, StreamingChatModel {
40+
41+
private static final Logger logger = LoggerFactory.getLogger(SolarChatModel.class);
42+
43+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
44+
45+
/**
46+
* The retry template used to retry the Solar API calls.
47+
*/
48+
public final RetryTemplate retryTemplate;
49+
50+
/**
51+
* The default options used for the chat completion requests.
52+
*/
53+
private final SolarChatOptions defaultOptions;
54+
55+
/**
56+
* Low-level access to the Solar API.
57+
*/
58+
private final SolarApi solarApi;
59+
60+
/**
61+
* Observation registry used for instrumentation.
62+
*/
63+
private final ObservationRegistry observationRegistry;
64+
65+
/**
66+
* Conventions to use for generating observations.
67+
*/
68+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
69+
70+
/**
71+
* Creates an instance of the SolarChatModel.
72+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar
73+
* Chat API.
74+
* @throws IllegalArgumentException if SolarApi is null
75+
*/
76+
public SolarChatModel(SolarApi SolarApi) {
77+
this(SolarApi, SolarChatOptions.builder().withModel(SolarApi.DEFAULT_CHAT_MODEL).withTemperature(0.7).build());
78+
}
79+
80+
/**
81+
* Initializes an instance of the SolarChatModel.
82+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar
83+
* Chat API.
84+
* @param options The SolarChatOptions to configure the chat client.
85+
*/
86+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options) {
87+
this(SolarApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
88+
}
89+
90+
/**
91+
* Initializes a new instance of the SolarChatModel.
92+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar
93+
* Chat API.
94+
* @param options The SolarChatOptions to configure the chat client.
95+
* @param retryTemplate The retry template.
96+
*/
97+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate) {
98+
this(SolarApi, options, retryTemplate, ObservationRegistry.NOOP);
99+
}
100+
101+
/**
102+
* Initializes a new instance of the SolarChatModel.
103+
* @param SolarApi The SolarApi instance to be used for interacting with the Solar
104+
* Chat API.
105+
* @param options The SolarChatOptions to configure the chat client.
106+
* @param retryTemplate The retry template.
107+
* @param observationRegistry The ObservationRegistry used for instrumentation.
108+
*/
109+
public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate,
110+
ObservationRegistry observationRegistry) {
111+
Assert.notNull(SolarApi, "SolarApi must not be null");
112+
Assert.notNull(options, "Options must not be null");
113+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
114+
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
115+
this.solarApi = SolarApi;
116+
this.defaultOptions = options;
117+
this.retryTemplate = retryTemplate;
118+
this.observationRegistry = observationRegistry;
119+
}
120+
121+
@Override
122+
public ChatResponse call(Prompt prompt) {
123+
SolarApi.ChatCompletionRequest request = createRequest(prompt, false);
124+
125+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
126+
.prompt(prompt)
127+
.provider(SolarConstants.PROVIDER_NAME)
128+
.requestOptions(buildRequestOptions(request))
129+
.build();
130+
131+
return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
132+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
133+
this.observationRegistry)
134+
.observe(() -> {
135+
ResponseEntity<SolarApi.ChatCompletion> completionEntity = this.retryTemplate
136+
.execute(ctx -> this.solarApi.chatCompletionEntity(request));
137+
138+
var chatCompletion = completionEntity.getBody();
139+
if (chatCompletion == null) {
140+
logger.warn("No chat completion returned for prompt: {}", prompt);
141+
return new ChatResponse(List.of());
142+
}
143+
144+
// @formatter:off
145+
Map<String, Object> metadata = Map.of(
146+
"id", chatCompletion.id(),
147+
"role", SolarApi.ChatCompletionMessage.Role.ASSISTANT
148+
);
149+
// @formatter:on
150+
151+
var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).message().content(),
152+
metadata);
153+
List<Generation> generations = Collections.singletonList(new Generation(assistantMessage));
154+
ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model()));
155+
observationContext.setResponse(chatResponse);
156+
return chatResponse;
157+
});
158+
}
159+
160+
/**
161+
* Accessible for testing.
162+
*/
163+
public SolarApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
164+
var chatCompletionMessages = prompt.getInstructions()
165+
.stream()
166+
.map(m -> new SolarApi.ChatCompletionMessage(m.getContent(),
167+
SolarApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
168+
.toList();
169+
var systemMessageList = chatCompletionMessages.stream()
170+
.filter(msg -> msg.role() == SolarApi.ChatCompletionMessage.Role.SYSTEM)
171+
.toList();
172+
var userMessageList = chatCompletionMessages.stream()
173+
.filter(msg -> msg.role() != SolarApi.ChatCompletionMessage.Role.SYSTEM)
174+
.toList();
175+
176+
if (systemMessageList.size() > 1) {
177+
throw new IllegalArgumentException("Only one system message is allowed in the prompt");
178+
}
179+
180+
var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content();
181+
182+
var request = new SolarApi.ChatCompletionRequest(userMessageList, systemMessage, stream);
183+
184+
if (this.defaultOptions != null) {
185+
request = ModelOptionsUtils.merge(this.defaultOptions, request, SolarApi.ChatCompletionRequest.class);
186+
}
187+
188+
if (prompt.getOptions() != null) {
189+
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
190+
SolarChatOptions.class);
191+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, SolarApi.ChatCompletionRequest.class);
192+
}
193+
return request;
194+
}
195+
196+
@Override
197+
public ChatOptions getDefaultOptions() {
198+
return SolarChatOptions.fromOptions(this.defaultOptions);
199+
}
200+
201+
private ChatOptions buildRequestOptions(SolarApi.ChatCompletionRequest request) {
202+
return ChatOptionsBuilder.builder()
203+
.withModel(request.model())
204+
.withFrequencyPenalty(request.frequencyPenalty())
205+
.withMaxTokens(request.maxTokens())
206+
.withPresencePenalty(request.presencePenalty())
207+
.withStopSequences(request.stop())
208+
.withTemperature(request.temperature())
209+
.withTopP(request.topP())
210+
.build();
211+
}
212+
213+
private ChatResponseMetadata from(SolarApi.ChatCompletion result, String model) {
214+
Assert.notNull(result, "Solar ChatCompletionResult must not be null");
215+
return ChatResponseMetadata.builder()
216+
.withId(result.id() != null ? result.id() : "")
217+
.withUsage(result.usage() != null ? SolarUsage.from(result.usage()) : new EmptyUsage())
218+
.withModel(model)
219+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
220+
.build();
221+
}
222+
223+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
224+
this.observationConvention = observationConvention;
225+
}
226+
227+
}

0 commit comments

Comments
 (0)