diff --git a/models/spring-ai-solar/README.md b/models/spring-ai-solar/README.md new file mode 100644 index 00000000000..a758f25d842 --- /dev/null +++ b/models/spring-ai-solar/README.md @@ -0,0 +1,3 @@ +[Solar Chat Documentation](https://console.upstage.ai/docs/capabilities/chat) + +[Solar Embedding Documentation](https://console.upstage.ai/docs/capabilities/embeddings) diff --git a/models/spring-ai-solar/pom.xml b/models/spring-ai-solar/pom.xml new file mode 100644 index 00000000000..8ec94a9afef --- /dev/null +++ b/models/spring-ai-solar/pom.xml @@ -0,0 +1,84 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-solar + jar + Spring AI Starter - Solar + Upstage Solar support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + io.micrometer + micrometer-observation-test + test + + + + + diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java new file mode 100644 index 00000000000..9986709eaf9 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.solar.aot; + +import static org.springframework.ai.aot.AiRuntimeHints.*; + +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +/** + * The SolarRuntimeHints class is responsible for registering runtime hints for Solar API + * classes. + * + * @author Seunghyeon Ji + */ +public class SolarRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(SolarApi.class)) { + hints.reflection().registerType(tr, mcs); + } + for (var tr : findJsonAnnotatedClassesInPackage(SolarApi.class)) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java new file mode 100644 index 00000000000..86fbdb3eb6b --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java @@ -0,0 +1,456 @@ +package org.springframework.ai.solar.api; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class SolarApi { + + public static final String DEFAULT_CHAT_MODEL = ChatModel.SOLAR_PRO.getValue(); + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBEDDING_QUERY.getValue(); + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final WebClient webClient; + + /** + * Create a new chat completion api with default base URL. + * @param apiKey Solar api key. + */ + public SolarApi(String apiKey) { + this(SolarConstants.DEFAULT_BASE_URL, apiKey); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + */ + public SolarApi(String baseUrl, String apiKey) { + this(baseUrl, apiKey, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder) { + this(baseUrl, apiKey, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + this(baseUrl, apiKey, restClientBuilder, WebClient.builder(), responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + Consumer finalHeaders = h -> { + h.setBearerAuth(apiKey); + h.setContentType(MediaType.APPLICATION_JSON); + }; + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(finalHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(finalHeaders).build(); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/solar/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/v1/solar/chat/completions", chatRequest.model) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)); + } + + /** + * Solar Chat Completion Models: + * Solar Model. + */ + public enum ChatModel { + + SOLAR_PRO("solar-pro"), SOLAR_MINI("solar-mini"), SOLAR_MINI_JA("solar-mini-ja"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * Solar Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + EMBEDDING_QUERY("embedding-query"), EMBEDDING_PASSAGE("embedding-passage"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * Creates a model response for the given chat conversation. + * + * @param messages A list of messages comprising the conversation so far. + * @param model The model name to generate the completion. Value in: "solar-pro" | + * "solar-mini" | "solar-mini-ja" + * @param maxTokens An optional parameter that limits the maximum number of tokens to + * generate. If max_tokens is set, sum of input tokens and max_tokens should be lower + * than or equal to context length of model. Default value is inf. + * @param stream An optional parameter that specifies whether a response should be + * sent as a stream. If set true, partial message deltas will be sent. Tokens will be + * sent as data-only server-sent events. Default value is false. + * @param temperature An optional parameter to set the sampling temperature. The value + * should lie between 0 and 2. Higher values like 0.8 result in a more random output, + * whereas lower values such as 0.2 enhance focus and determinism in the output. + * Default value is 0.7. not both. + * @param topP An optional parameter to trigger nucleus sampling. The tokens with + * top_p probability mass will be considered, which means, setting this value to 0.1 + * will consider tokens comprising the top 10% probability. + * @param responseFormat An object specifying the format that the model must generate. + * To generate JSON object without providing schema (JSON Mode), set response_format: + * {\"type\": \"json_object\"}. To generate JSON object with your own schema + * (Structured Outputs), set response_format: {“type”: “json_schema”, “json_schema”: { + * … your json schema … }}. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionRequest(@JsonProperty("messages") List messages, + @JsonProperty("model") String model, @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("stream") Boolean stream, @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, @JsonProperty("response_format") ResponseFormat responseFormat) { + + /** + * Shortcut constructor for a chat completion request with the given messages and + * model. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature An optional parameter to set the sampling temperature. The + * value should lie between 0 and 2. Higher values like 0.8 result in a more + * random output, whereas lower values such as 0.2 enhance focus and determinism + * in the output. Default value is 0.7. + */ + public ChatCompletionRequest(List messages, String model, Double temperature) { + this(messages, model, null, null, temperature, null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model and control for streaming. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param stream If set, partial message deltas will be sent.Tokens will be sent + * as data-only server-sent events as they become available, with the stream + * terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String model, boolean stream) { + this(messages, model, null, stream, null, null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model and control for streaming. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature An optional parameter to set the sampling temperature. The + * value should lie between 0 and 2. Higher values like 0.8 result in a more + * random output, whereas lower values such as 0.2 enhance focus and determinism + * in the output. Default value is 0.7. + * @param stream If set, partial message deltas will be sent.Tokens will be sent + * as data-only server-sent events as they become available, with the stream + * terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String model, Double temperature, + boolean stream) { + this(messages, model, null, stream, temperature, null, null); + } + + /** + * An object specifying the format that the model must output. + * + * @param type Must be one of 'json_object' or 'json_schema'. + * @param jsonSchema The JSON schema to be used for structured output. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResponseFormat(@JsonProperty("type") String type, + @JsonProperty("json_schema") String jsonSchema) { + } + } + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Can be a {@link String}. The + * response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} + * types. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionMessage(@JsonProperty("content") Object rawContent, @JsonProperty("role") Role role) { + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * The role of the author of this message. + */ + public enum Role { + + /** + * System message. + */ + @JsonProperty("system") + SYSTEM, + /** + * User message. + */ + @JsonProperty("user") + USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") + ASSISTANT, + /** + * Tool message. + */ + @JsonProperty("tool") + TOOL + + } + } + + /** + * Represents a chat completion response returned by model, based on the provided + * input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param object The object type, which is always 'chat.completion'. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. Each chunk has the same timestamp. + * @param model A string representing the version of the model being used. + * @param systemFingerprint This field is not yet available. + * @param choices A list of chat completion choices. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletion(@JsonProperty("id") String id, @JsonProperty("object") String object, + @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") Object systemFingerprint, @JsonProperty("choices") List choices, + @JsonProperty("usage") Usage usage) { + /** + * Choice statistics for the completion request. + * + * @param finishReason A unique identifier for the chat completion. Each chunk has + * the same ID. + * @param index The index of the choice in the list of choices. + * @param message A chat completion message generated by the model. + * @param logprobs This field is not yet available. + * @param usage Usage statistics for the completion request. + */ + public record Choice(@JsonProperty("finish_reason") String finishReason, @JsonProperty("index") int index, + @JsonProperty("message") Message message, @JsonProperty("logprobs") Object logprobs, + @JsonProperty("usage") Usage usage) { + } + + /** + * A chat completion message generated by the model. + * + * @param content The contents of the message. + * @param role The role of the author of this message. + * @param toolCalls A list of tools selected by model to call. + */ + public record Message(@JsonProperty("content") String content, @JsonProperty("role") String role, + @JsonProperty("tool_calls") ToolCalls toolCalls) { + } + + /** + * A list of tools selected by model to call. + * + * @param id The ID of tool calls. + * @param type The type of tool. + * @param function A function object to call. + */ + public record ToolCalls(@JsonProperty("id") String id, @JsonProperty("type") String type, + @JsonProperty("function") Function function) { + } + + /** + * A function object to call. + * + * @param name The name of function to call. + * @param arguments A JSON input to function. + */ + public record Function(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { + } + + /** + * Usage statistics for the completion request. + * + * @param completionTokens Number of tokens in the generated completion. + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + + * completion). + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage(@JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + } + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based + * on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param object The object type, which is always 'chat.completion.chunk'. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. Each chunk has the same timestamp. + * @param model A string representing the version of the model being used. + * @param systemFingerprint This field is not yet available. + * @param choices A list of chat completion choices. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionChunk(@JsonProperty("id") String id, @JsonProperty("object") String object, + @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") Object systemFingerprint, + @JsonProperty("choices") List choices) { + /** + * A list of chat completion choices. + * + * @param finishReason The reason the model stopped generating tokens. This will + * be stop if the model hit a natural stop point or a provided stop sequence, + * length if the maximum number of tokens specified in the request was reached. + * @param index The index of the choice in the list of choices. + * @param delta A chat completion message generated by the model. + * @param logprobs This field is not yet available. + */ + public record Choice(@JsonProperty("finish_reason") String finishReason, @JsonProperty("index") int index, + @JsonProperty("delta") Delta delta, @JsonProperty("logprobs") Object logprobs) { + } + + /** + * A chat completion message generated by the model. + * + * @param content The contents of the message. + * @param role The role of the author of this message. + * @param toolCalls A list of tools selected by model to call. + */ + public record Delta(@JsonProperty("content") String content, @JsonProperty("role") String role, + @JsonProperty("tool_calls") ToolCalls toolCalls) { + } + + /** + * A list of tools selected by model to call. + * + * @param id The ID of tool calls. + * @param type The type of tool. + * @param function A function object to call. + */ + public record ToolCalls(@JsonProperty("id") String id, @JsonProperty("type") String type, + @JsonProperty("function") Function function) { + } + + /** + * A function object to call. + * + * @param name The name of function to call. + * @param arguments A JSON input to function. + */ + public record Function(@JsonProperty("name") String name, @JsonProperty("arguments") String arguments) { + } + + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarConstants.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarConstants.java new file mode 100644 index 00000000000..c50e1c99c34 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarConstants.java @@ -0,0 +1,19 @@ +package org.springframework.ai.solar.api; + +import org.springframework.ai.observation.conventions.AiProvider; + +/** + * Common value constants for Solar api. + * + * @author Seunghyeon Ji + */ +public class SolarConstants { + + public static final String DEFAULT_BASE_URL = "https://api.upstage.ai"; + + public static final String PROVIDER_NAME = AiProvider.UPSTAGE.value(); + + private SolarConstants() { + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatModel.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatModel.java new file mode 100644 index 00000000000..bed62e1d410 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatModel.java @@ -0,0 +1,281 @@ +package org.springframework.ai.solar.chat; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.ai.solar.api.SolarApi.ChatCompletion; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionChunk; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage.Role; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionRequest; +import org.springframework.ai.solar.api.SolarConstants; +import org.springframework.ai.solar.metadata.SolarUsage; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class SolarChatModel implements ChatModel, StreamingChatModel { + + private static final Logger logger = LoggerFactory.getLogger(SolarChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + /** + * The retry template used to retry the Solar API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * The default options used for the chat completion requests. + */ + private final SolarChatOptions defaultOptions; + + /** + * Low-level access to the Solar API. + */ + private final SolarApi solarApi; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + /** + * Creates an instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @throws IllegalArgumentException if SolarApi is null + */ + public SolarChatModel(SolarApi SolarApi) { + this(SolarApi, SolarChatOptions.builder().withModel(SolarApi.DEFAULT_CHAT_MODEL).withTemperature(0.7).build()); + } + + /** + * Initializes an instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options) { + this(SolarApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + * @param retryTemplate The retry template. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate) { + this(SolarApi, options, retryTemplate, ObservationRegistry.NOOP); + } + + /** + * Initializes a new instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + * @param retryTemplate The retry template. + * @param observationRegistry The ObservationRegistry used for instrumentation. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + Assert.notNull(SolarApi, "SolarApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); + this.solarApi = SolarApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public ChatResponse call(Prompt prompt) { + ChatCompletionRequest request = createRequest(prompt, false); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(SolarConstants.PROVIDER_NAME) + .requestOptions(buildRequestOptions(request)) + .build(); + + return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.solarApi.chatCompletionEntity(request)); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", SolarApi.ChatCompletionMessage.Role.ASSISTANT + ); + // @formatter:on + + var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).message().content(), + metadata); + List generations = Collections.singletonList(new Generation(assistantMessage)); + ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model())); + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + + @Override + public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); + + var completionChunks = this.solarApi.chatCompletionStream(request); + + final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(SolarConstants.PROVIDER_NAME) + .requestOptions(buildRequestOptions(request)) + .build(); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux chatResponse = completionChunks.map(this::toChatCompletion) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", Role.ASSISTANT + ); + // @formatter:on + + var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).delta().content(), + metadata); + List generations = Collections.singletonList(new Generation(assistantMessage)); + return new ChatResponse(generations, from(chatCompletion, request.model())); + })) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse); + + }); + } + + private ChatCompletionChunk toChatCompletion(ChatCompletionChunk chunk) { + return new ChatCompletionChunk(chunk.id(), chunk.object(), chunk.created(), chunk.model(), + chunk.systemFingerprint(), chunk.choices()); + } + + public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + var chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new ChatCompletionMessage(m.getContent(), + ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + .toList(); + var systemMessageList = chatCompletionMessages.stream() + .filter(msg -> msg.role() == ChatCompletionMessage.Role.SYSTEM) + .toList(); + var userMessageList = chatCompletionMessages.stream() + .filter(msg -> msg.role() != ChatCompletionMessage.Role.SYSTEM) + .toList(); + + if (systemMessageList.size() > 1) { + throw new IllegalArgumentException("Only one system message is allowed in the prompt"); + } + + var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content(); + + var request = new SolarApi.ChatCompletionRequest(userMessageList, systemMessage, stream); + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(this.defaultOptions, request, SolarApi.ChatCompletionRequest.class); + } + + if (prompt.getOptions() != null) { + var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + SolarChatOptions.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, SolarApi.ChatCompletionRequest.class); + } + return request; + } + + @Override + public ChatOptions getDefaultOptions() { + return SolarChatOptions.fromOptions(this.defaultOptions); + } + + private ChatOptions buildRequestOptions(ChatCompletionRequest request) { + return ChatOptionsBuilder.builder() + .withModel(request.model()) + .withMaxTokens(request.maxTokens()) + .withTemperature(request.temperature()) + .withTopP(request.topP()) + .build(); + } + + private ChatResponseMetadata from(ChatCompletion result, String model) { + Assert.notNull(result, "Solar ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id() != null ? result.id() : "") + .withUsage(result.usage() != null ? SolarUsage.from(result.usage()) : new EmptyUsage()) + .withModel(model) + .withKeyValue("created", result.created() != null ? result.created() : 0L) + .build(); + } + + private ChatResponseMetadata from(ChatCompletionChunk result, String model) { + Assert.notNull(result, "Solar ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id() != null ? result.id() : "") + .withModel(model) + .withKeyValue("created", result.created() != null ? result.created() : 0L) + .build(); + } + + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatOptions.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatOptions.java new file mode 100644 index 00000000000..ff8970ce455 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/chat/SolarChatOptions.java @@ -0,0 +1,328 @@ +package org.springframework.ai.solar.chat; + +import java.util.List; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.solar.api.SolarApi; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * SolarChatOptions represents the options for performing chat completion using the Solar + * API. It provides methods to set and retrieve various options like model, frequency + * penalty, max tokens, etc. + * + * @author Seunghyeon Ji + * @since 1.0 + * @see ChatOptions + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class SolarChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_output_tokens") Integer maxTokens; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") SolarApi.ChatCompletionRequest.ResponseFormat responseFormat; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 2. Higher values like 0.7 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static SolarChatOptions fromOptions(SolarChatOptions fromOptions) { + return SolarChatOptions.builder() + .withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public SolarApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(SolarApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + SolarChatOptions other = (SolarChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) { + return false; + } + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { + return false; + } + if (this.maxTokens == null) { + if (other.maxTokens != null) { + return false; + } + } + else if (!this.maxTokens.equals(other.maxTokens)) { + return false; + } + if (this.presencePenalty == null) { + if (other.presencePenalty != null) { + return false; + } + } + else if (!this.presencePenalty.equals(other.presencePenalty)) { + return false; + } + if (this.responseFormat == null) { + if (other.responseFormat != null) { + return false; + } + } + else if (!this.responseFormat.equals(other.responseFormat)) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + return true; + } + + @Override + public SolarChatOptions copy() { + return fromOptions(this); + } + + public static class Builder { + + protected SolarChatOptions options; + + public Builder() { + this.options = new SolarChatOptions(); + } + + public Builder(SolarChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(SolarApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public SolarChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java new file mode 100644 index 00000000000..962c9e2f8da --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.solar.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.util.Assert; + +/** + * {@link Usage} implementation for {@literal Solar}. + * + * @author Seunghyeon Ji + */ +public class SolarUsage implements Usage { + + private final SolarApi.ChatCompletion.Usage usage; + + protected SolarUsage(SolarApi.ChatCompletion.Usage usage) { + Assert.notNull(usage, "Solar Usage must not be null"); + this.usage = usage; + } + + public static SolarUsage from(SolarApi.ChatCompletion.Usage usage) { + return new SolarUsage(usage); + } + + protected SolarApi.ChatCompletion.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return 0L; + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..8ddbd00cf52 --- /dev/null +++ b/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.solar.aot.SolarRuntimeHints diff --git a/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/api/SolarApiIT.java b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/api/SolarApiIT.java new file mode 100644 index 00000000000..6d5b31b61fe --- /dev/null +++ b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/api/SolarApiIT.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.solar.api; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.http.ResponseEntity; + +import org.springframework.ai.solar.api.SolarApi.ChatCompletionChunk; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionMessage.Role; +import org.springframework.ai.solar.api.SolarApi.ChatCompletionRequest; + +import reactor.core.publisher.Flux; + +/** + * @author Seunghyeon Ji + */ +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarApiIT { + + SolarApi solarApi = new SolarApi(System.getenv("SOLAR_API_KEY")); + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", + SolarApi.ChatCompletionMessage.Role.USER); + ResponseEntity response = this.solarApi.chatCompletionEntity( + new SolarApi.ChatCompletionRequest(List.of(chatCompletionMessage), SolarApi.DEFAULT_CHAT_MODEL, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new SolarApi.ChatCompletionMessage("Hello world", Role.USER); + Flux response = this.solarApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), SolarApi.ChatModel.SOLAR_PRO.value, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + +} diff --git a/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarChatModelIT.java b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarChatModelIT.java new file mode 100644 index 00000000000..887acb5ef5b --- /dev/null +++ b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarChatModelIT.java @@ -0,0 +1,95 @@ +/* +* Copyright 2023-2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.springframework.ai.solar.chat; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +import reactor.core.publisher.Flux; + +/** + * @author Seunghyeon Ji + */ +@SpringBootTest(classes = SolarTestConfiguration.class) +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarChatModelIT { + + @Autowired + protected ChatModel chatModel; + + @Autowired + protected StreamingChatModel streamingChatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void streamRoleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Flux flux = this.streamingChatModel.stream(prompt); + + List responses = flux.collectList().block(); + assertThat(responses.size()).isGreaterThan(1); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + } + +} diff --git a/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarTestConfiguration.java b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarTestConfiguration.java new file mode 100644 index 00000000000..d7346f7775e --- /dev/null +++ b/models/spring-ai-solar/src/test/java/org/springframework/ai/solar/chat/SolarTestConfiguration.java @@ -0,0 +1,33 @@ +package org.springframework.ai.solar.chat; + +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Seunghyeon Ji + */ +@SpringBootConfiguration +public class SolarTestConfiguration { + + @Bean + public SolarApi solarApi() { + return new SolarApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("SOLAR_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name SOLAR_API_KEY"); + } + return apiKey; + } + + @Bean + public SolarChatModel solarChatModel(SolarApi api) { + return new SolarChatModel(api); + } + +} diff --git a/models/spring-ai-solar/src/test/resources/prompts/system-message.st b/models/spring-ai-solar/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..579febd8d9b --- /dev/null +++ b/models/spring-ai-solar/src/test/resources/prompts/system-message.st @@ -0,0 +1,3 @@ +You are an AI assistant that helps people find information. +Your name is {name}. +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file diff --git a/pom.xml b/pom.xml index 898f27b7a7e..035780bcf15 100644 --- a/pom.xml +++ b/pom.xml @@ -99,6 +99,7 @@ models/spring-ai-openai models/spring-ai-postgresml models/spring-ai-qianfan + models/spring-ai-solar models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-embedding diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index e723b679b02..b87aec4d10f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -99,7 +99,12 @@ public enum AiProvider { /** * AI system provided by ONNX. */ - ONNX("onnx"); + ONNX("onnx"), + + /** + * AI system provided by Upstage. + */ + UPSTAGE("upstage"); private final String value;