diff --git a/models/spring-ai-solar/README.md b/models/spring-ai-solar/README.md new file mode 100644 index 00000000000..a758f25d842 --- /dev/null +++ b/models/spring-ai-solar/README.md @@ -0,0 +1,3 @@ +[Solar Chat Documentation](https://console.upstage.ai/docs/capabilities/chat) + +[Solar Embedding Documentation](https://console.upstage.ai/docs/capabilities/embeddings) diff --git a/models/spring-ai-solar/pom.xml b/models/spring-ai-solar/pom.xml new file mode 100644 index 00000000000..21df7b6eca3 --- /dev/null +++ b/models/spring-ai-solar/pom.xml @@ -0,0 +1,84 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-solar + jar + Spring AI Solar + Upstage Solar support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + io.micrometer + micrometer-observation-test + test + + + + + diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatModel.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatModel.java new file mode 100644 index 00000000000..003e8379563 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatModel.java @@ -0,0 +1,227 @@ +package org.springframework.ai.solar; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.ai.solar.api.common.SolarConstants; +import org.springframework.ai.solar.metadata.SolarUsage; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class SolarChatModel implements ChatModel, StreamingChatModel { + + private static final Logger logger = LoggerFactory.getLogger(SolarChatModel.class); + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + /** + * The retry template used to retry the Solar API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * The default options used for the chat completion requests. + */ + private final SolarChatOptions defaultOptions; + + /** + * Low-level access to the Solar API. + */ + private final SolarApi solarApi; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + /** + * Creates an instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @throws IllegalArgumentException if SolarApi is null + */ + public SolarChatModel(SolarApi SolarApi) { + this(SolarApi, SolarChatOptions.builder().withModel(SolarApi.DEFAULT_CHAT_MODEL).withTemperature(0.7).build()); + } + + /** + * Initializes an instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options) { + this(SolarApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + * @param retryTemplate The retry template. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate) { + this(SolarApi, options, retryTemplate, ObservationRegistry.NOOP); + } + + /** + * Initializes a new instance of the SolarChatModel. + * @param SolarApi The SolarApi instance to be used for interacting with the Solar + * Chat API. + * @param options The SolarChatOptions to configure the chat client. + * @param retryTemplate The retry template. + * @param observationRegistry The ObservationRegistry used for instrumentation. + */ + public SolarChatModel(SolarApi SolarApi, SolarChatOptions options, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + Assert.notNull(SolarApi, "SolarApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); + this.solarApi = SolarApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public ChatResponse call(Prompt prompt) { + SolarApi.ChatCompletionRequest request = createRequest(prompt, false); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(SolarConstants.PROVIDER_NAME) + .requestOptions(buildRequestOptions(request)) + .build(); + + return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.solarApi.chatCompletionEntity(request)); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", SolarApi.ChatCompletionMessage.Role.ASSISTANT + ); + // @formatter:on + + var assistantMessage = new AssistantMessage(chatCompletion.choices().get(0).message().content(), + metadata); + List generations = Collections.singletonList(new Generation(assistantMessage)); + ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model())); + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + + /** + * Accessible for testing. + */ + public SolarApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + var chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new SolarApi.ChatCompletionMessage(m.getContent(), + SolarApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + .toList(); + var systemMessageList = chatCompletionMessages.stream() + .filter(msg -> msg.role() == SolarApi.ChatCompletionMessage.Role.SYSTEM) + .toList(); + var userMessageList = chatCompletionMessages.stream() + .filter(msg -> msg.role() != SolarApi.ChatCompletionMessage.Role.SYSTEM) + .toList(); + + if (systemMessageList.size() > 1) { + throw new IllegalArgumentException("Only one system message is allowed in the prompt"); + } + + var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content(); + + var request = new SolarApi.ChatCompletionRequest(userMessageList, systemMessage, stream); + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(this.defaultOptions, request, SolarApi.ChatCompletionRequest.class); + } + + if (prompt.getOptions() != null) { + var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + SolarChatOptions.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, SolarApi.ChatCompletionRequest.class); + } + return request; + } + + @Override + public ChatOptions getDefaultOptions() { + return SolarChatOptions.fromOptions(this.defaultOptions); + } + + private ChatOptions buildRequestOptions(SolarApi.ChatCompletionRequest request) { + return ChatOptionsBuilder.builder() + .withModel(request.model()) + .withFrequencyPenalty(request.frequencyPenalty()) + .withMaxTokens(request.maxTokens()) + .withPresencePenalty(request.presencePenalty()) + .withStopSequences(request.stop()) + .withTemperature(request.temperature()) + .withTopP(request.topP()) + .build(); + } + + private ChatResponseMetadata from(SolarApi.ChatCompletion result, String model) { + Assert.notNull(result, "Solar ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id() != null ? result.id() : "") + .withUsage(result.usage() != null ? SolarUsage.from(result.usage()) : new EmptyUsage()) + .withModel(model) + .withKeyValue("created", result.created() != null ? result.created() : 0L) + .build(); + } + + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatOptions.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatOptions.java new file mode 100644 index 00000000000..0f88eabaaf3 --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarChatOptions.java @@ -0,0 +1,328 @@ +package org.springframework.ai.solar; + +import java.util.List; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.solar.api.SolarApi; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * SolarChatOptions represents the options for performing chat completion using the Solar + * API. It provides methods to set and retrieve various options like model, frequency + * penalty, max tokens, etc. + * + * @author Seunghyeon Ji + * @since 1.0 + * @see ChatOptions + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class SolarChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_output_tokens") Integer maxTokens; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") SolarApi.ChatCompletionRequest.ResponseFormat responseFormat; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static SolarChatOptions fromOptions(SolarChatOptions fromOptions) { + return SolarChatOptions.builder() + .withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public SolarApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(SolarApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + SolarChatOptions other = (SolarChatOptions) obj; + if (this.model == null) { + if (other.model != null) { + return false; + } + } + else if (!this.model.equals(other.model)) { + return false; + } + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) { + return false; + } + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { + return false; + } + if (this.maxTokens == null) { + if (other.maxTokens != null) { + return false; + } + } + else if (!this.maxTokens.equals(other.maxTokens)) { + return false; + } + if (this.presencePenalty == null) { + if (other.presencePenalty != null) { + return false; + } + } + else if (!this.presencePenalty.equals(other.presencePenalty)) { + return false; + } + if (this.responseFormat == null) { + if (other.responseFormat != null) { + return false; + } + } + else if (!this.responseFormat.equals(other.responseFormat)) { + return false; + } + if (this.stop == null) { + if (other.stop != null) { + return false; + } + } + else if (!this.stop.equals(other.stop)) { + return false; + } + if (this.temperature == null) { + if (other.temperature != null) { + return false; + } + } + else if (!this.temperature.equals(other.temperature)) { + return false; + } + if (this.topP == null) { + if (other.topP != null) { + return false; + } + } + else if (!this.topP.equals(other.topP)) { + return false; + } + return true; + } + + @Override + public SolarChatOptions copy() { + return fromOptions(this); + } + + public static class Builder { + + protected SolarChatOptions options; + + public Builder() { + this.options = new SolarChatOptions(); + } + + public Builder(SolarChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(SolarApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public SolarChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingModel.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingModel.java new file mode 100644 index 00000000000..95ce091a05a --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingModel.java @@ -0,0 +1,192 @@ +package org.springframework.ai.solar; + +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.ai.solar.api.common.SolarConstants; +import org.springframework.ai.solar.metadata.SolarUsage; +import org.springframework.lang.Nullable; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import io.micrometer.observation.ObservationRegistry; + +public class SolarEmbeddingModel extends AbstractEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(SolarEmbeddingModel.class); + + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + + private final SolarEmbeddingOptions defaultOptions; + + private final RetryTemplate retryTemplate; + + private final SolarApi solarApi; + + private final MetadataMode metadataMode; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + /** + * Constructor for the SolarEmbeddingModel class. + * @param solarApi The SolarApi instance to use for making API requests. + */ + public SolarEmbeddingModel(SolarApi solarApi) { + this(solarApi, MetadataMode.EMBED); + } + + /** + * Initializes a new instance of the SolarEmbeddingModel class. + * @param solarApi The SolarApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + */ + public SolarEmbeddingModel(SolarApi solarApi, MetadataMode metadataMode) { + this(solarApi, metadataMode, + SolarEmbeddingOptions.builder().withModel(SolarApi.DEFAULT_EMBEDDING_MODEL).build()); + } + + /** + * Initializes a new instance of the SolarEmbeddingModel class. + * @param solarApi The SolarApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + * @param SolarEmbeddingOptions The options for Solar embedding. + */ + public SolarEmbeddingModel(SolarApi solarApi, MetadataMode metadataMode, + SolarEmbeddingOptions SolarEmbeddingOptions) { + this(solarApi, metadataMode, SolarEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the SolarEmbeddingModel class. + * @param solarApi The SolarApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + * @param SolarEmbeddingOptions The options for Solar embedding. + * @param retryTemplate - The RetryTemplate for retrying failed API requests. + */ + public SolarEmbeddingModel(SolarApi solarApi, MetadataMode metadataMode, + SolarEmbeddingOptions SolarEmbeddingOptions, RetryTemplate retryTemplate) { + this(solarApi, metadataMode, SolarEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP); + } + + /** + * Initializes a new instance of the SolarEmbeddingModel class. + * @param solarApi - The SolarApi instance to use for making API requests. + * @param metadataMode - The mode for generating metadata. + * @param options - The options for Solar embedding. + * @param retryTemplate - The RetryTemplate for retrying failed API requests. + * @param observationRegistry - The ObservationRegistry used for instrumentation. + */ + public SolarEmbeddingModel(SolarApi solarApi, MetadataMode metadataMode, SolarEmbeddingOptions options, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + Assert.notNull(solarApi, "SolarApi must not be null"); + Assert.notNull(metadataMode, "metadataMode must not be null"); + Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + + this.solarApi = solarApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + @Override + public float[] embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + SolarEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions); + SolarApi.EmbeddingRequest apiRequest = new SolarApi.EmbeddingRequest(request.getInstructions(), + requestOptions.getModel()); + + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(request) + .provider(SolarConstants.PROVIDER_NAME) + .requestOptions(requestOptions) + .build(); + + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + SolarApi.EmbeddingList apiEmbeddingResponse = this.retryTemplate + .execute(ctx -> this.solarApi.embeddings(apiRequest).getBody()); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned for request: {}", request); + return new EmbeddingResponse(List.of()); + } + + if (apiEmbeddingResponse.errorNsg() != null) { + logger.error("Error message returned for request: {}", apiEmbeddingResponse.errorNsg()); + throw new RuntimeException("Embedding failed: error code:" + apiEmbeddingResponse.errorCode() + + ", message:" + apiEmbeddingResponse.errorNsg()); + } + + var metadata = new EmbeddingResponseMetadata(apiRequest.model(), + SolarUsage.from(apiEmbeddingResponse.usage())); + + List embeddings = apiEmbeddingResponse.data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) + .toList(); + + EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata); + + observationContext.setResponse(embeddingResponse); + + return embeddingResponse; + }); + } + + /** + * Merge runtime and default {@link EmbeddingOptions} to compute the final options to + * use in the request. + */ + private SolarEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, + SolarEmbeddingOptions defaultOptions) { + var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, + SolarEmbeddingOptions.class); + + if (runtimeOptionsForProvider == null) { + return defaultOptions; + } + + return SolarEmbeddingOptions.builder() + .withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel())) + .build(); + } + + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingOptions.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingOptions.java new file mode 100644 index 00000000000..c36ab4c7fdc --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/SolarEmbeddingOptions.java @@ -0,0 +1,60 @@ +package org.springframework.ai.solar; + +import org.springframework.ai.embedding.EmbeddingOptions; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * This class represents the options for Solar embedding. + * + * @author Seunghyeon Ji + * @since 1.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class SolarEmbeddingOptions implements EmbeddingOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + + public static Builder builder() { + return new Builder(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + + public static class Builder { + + protected SolarEmbeddingOptions options; + + public Builder() { + this.options = new SolarEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public SolarEmbeddingOptions build() { + return this.options; + } + } +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java new file mode 100644 index 00000000000..3197f09b83b --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/aot/SolarRuntimeHints.java @@ -0,0 +1,28 @@ +package org.springframework.ai.solar.aot; + +import static org.springframework.ai.aot.AiRuntimeHints.*; + +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +/** + * The SolarRuntimeHints class is responsible for registering runtime hints for Solar API + * classes. + * + * @author Seungheon Ji + */ +public class SolarRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(SolarApi.class)) { + hints.reflection().registerType(tr, mcs); + } + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java new file mode 100644 index 00000000000..cbb0ffd637f --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/SolarApi.java @@ -0,0 +1,457 @@ +package org.springframework.ai.solar.api; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.solar.api.common.SolarConstants; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class SolarApi { + + public static final String DEFAULT_CHAT_MODEL = ChatModel.SOLAR_PRO.getValue(); + + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.EMBEDDING_QUERY.getValue(); + + private static final Predicate SSE_DONE_PREDICATE = ChatCompletionChunk::end; + + private final RestClient restClient; + + private final WebClient webClient; + + /** + * Create a new chat completion api with default base URL. + * @param apiKey Solar api key. + */ + public SolarApi(String apiKey) { + this(SolarConstants.DEFAULT_BASE_URL, apiKey); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + */ + public SolarApi(String baseUrl, String apiKey) { + this(baseUrl, apiKey, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder) { + this(baseUrl, apiKey, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + this(baseUrl, apiKey, restClientBuilder, WebClient.builder(), responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey Solar api key. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public SolarApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + Consumer finalHeaders = h -> { + h.setBearerAuth(apiKey); + h.setContentType(MediaType.APPLICATION_JSON); + }; + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(finalHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(finalHeaders).build(); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/v1/solar/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + // The input must not an empty string, and any array must be 16 dimensions or + // less. + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); + + return this.restClient.post() + .uri("/v1/solar/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + + /** + * Solar Chat Completion Models: + * Solar Model. + */ + public enum ChatModel { + + SOLAR_PRO("solar-pro"), SOLAR_MINI("solar-mini"), SOLAR_MINI_JA("solar-mini-ja"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * Solar Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + EMBEDDING_QUERY("embedding-query"), EMBEDDING_PASSAGE("embedding-passage"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + /** + * Creates a model response for the given chat conversation. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new + * tokens based on their existing frequency in the text so far, decreasing the model's + * likelihood to repeat the same line verbatim. + * @param maxTokens The maximum number of tokens to generate in the chat completion. + * The total length of input tokens and generated tokens is limited by the model's + * context length. appear in the text so far, increasing the model's likelihood to + * talk about new topics. + * @param responseFormat An object specifying the format that the model must output. + * Setting to { "type": "json_object" } enables JSON mode, which guarantees the + * message the model generates is valid JSON. + * @param stop Up to 4 sequences where the API will stop generating further tokens. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as + * data-only server-sent events as they become available, with the stream terminated + * by a data: [DONE] message. + * @param temperature What sampling temperature to use, between 0 and 1. Higher values + * like 0.8 will make the output more random, while lower values like 0.2 will make it + * more focused and deterministic. We generally recommend altering this or top_p but + * not both. + * @param topP An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p probability mass. So + * 0.1 means only the tokens comprising the top 10% probability mass are considered. + * We generally recommend altering this or temperature but not both. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionRequest(@JsonProperty("messages") List messages, + @JsonProperty("system") String system, @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("max_output_tokens") Integer maxTokens, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("response_format") ResponseFormat responseFormat, @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP) { + + /** + * Shortcut constructor for a chat completion request with the given messages and + * model. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + */ + public ChatCompletionRequest(List messages, String system, String model, + Double temperature) { + this(messages, system, model, null, null, null, null, null, false, temperature, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, + * model and control for streaming. + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param stream If set, partial message deltas will be sent.Tokens will be sent + * as data-only server-sent events as they become available, with the stream + * terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String model, boolean stream) { + this(messages, null, model, null, null, null, null, null, stream, null, null); + } + + /** + * An object specifying the format that the model must output. + * + * @param type Must be one of 'text' or 'json_object'. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResponseFormat(@JsonProperty("type") String type) { + } + } + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Can be a {@link String}. The + * response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} + * types. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionMessage(@JsonProperty("content") Object rawContent, @JsonProperty("role") Role role) { + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * The role of the author of this message. + */ + public enum Role { + + /** + * System message. + */ + @JsonProperty("system") + SYSTEM, + /** + * User message. + */ + @JsonProperty("user") + USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") + ASSISTANT + + } + } + + /** + * Represents a chat completion response returned by model, based on the provided + * input. + * + * @param id A unique identifier for the chat completion. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. used in conjunction with the seed request parameter to understand when + * backend changes have been made that might impact determinism. + * @param object The object type, which is always chat.completion. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletion(@JsonProperty("id") String id, @JsonProperty("object") String object, + @JsonProperty("created") Long created, @JsonProperty("model") String model, + @JsonProperty("choices") List choices, @JsonProperty("usage") Usage usage, + @JsonProperty("system_fingerprint") String systemFingerprint) { + public record Choice(@JsonProperty("index") int index, @JsonProperty("message") Message message, + @JsonProperty("logprobs") Object logprobs, @JsonProperty("finish_reason") String finishReason) { + } + + public record Message(@JsonProperty("role") String role, @JsonProperty("content") String content) { + } + } + + /** + * Usage statistics for the completion request. + * + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + + * completion). + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Usage(@JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("prompt_tokens") Integer promptTokens, @JsonProperty("total_tokens") Integer totalTokens) { + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based + * on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param object The object type, which is always 'chat.completion.chunk'. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. Each chunk has the same timestamp. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionChunk(@JsonProperty("id") String id, @JsonProperty("object") String object, + @JsonProperty("created") Long created, @JsonProperty("choices") List choices, + @JsonProperty("finish_reason") String finishReason, @JsonProperty("is_end") Boolean end, + @JsonProperty("usage") Usage usage, @JsonProperty("system_finger_print") String systemFingerPrint) { + public record Choice(@JsonProperty("index") int index, @JsonProperty("delta") Delta delta, + @JsonProperty("logprobs") Object logprobs, @JsonProperty("finish_reason") String finishReason) { + } + + public record Delta(@JsonProperty("role") String role, @JsonProperty("content") String content) { + } + } + + /** + * Creates an embedding vector representing the input text. + * + * @param texts Input text to embed, encoded as a string or array of tokens. + * @param user A unique identifier representing your end-user, which can help Solar to + * monitor and detect abuse. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingRequest(@JsonProperty("input") List texts, @JsonProperty("model") String model, + @JsonProperty("user_id") String user) { + /** + * Create an embedding request with the given input. Embedding model is set to + * 'bge_large_zh'. + * @param text Input text to embed. + */ + public EmbeddingRequest(String text) { + this(List.of(text), DEFAULT_EMBEDDING_MODEL, null); + } + + /** + * Create an embedding request with the given input. + * @param text Input text to embed. + * @param model ID of the model to use. + * @param userId A unique identifier representing your end-user, which can help + * Solar to monitor and detect abuse. + */ + public EmbeddingRequest(String text, String model, String userId) { + this(List.of(text), model, userId); + } + + /** + * Create an embedding request with the given input. Embedding model is set to + * 'bge_large_zh'. + * @param texts Input text to embed. + */ + public EmbeddingRequest(List texts) { + this(texts, DEFAULT_EMBEDDING_MODEL, null); + } + + /** + * Create an embedding request with the given input. + * @param texts Input text to embed. + * @param model ID of the model to use. + */ + public EmbeddingRequest(List texts, String model) { + this(texts, model, null); + } + } + + /** + * Represents an embedding vector returned by embedding endpoint. + * + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. + * @param object The object type, which is always 'embedding'. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Embedding( + // @formatter:off + @JsonProperty("index") Integer index, + @JsonProperty("embedding") float[] embedding, + @JsonProperty("object") String object) { + // @formatter:on + + /** + * Create an embedding with the given index, embedding and object type set to + * 'embedding'. + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. + */ + public Embedding(Integer index, float[] embedding) { + this(index, embedding, "embedding"); + } + } + + /** + * List of multiple embedding responses. + * + * @param object Must have value "embedding_list". + * @param data List of entities. + * @param model ID of the model to use. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record EmbeddingList( + // @formatter:off + @JsonProperty("object") String object, + @JsonProperty("data") List data, + @JsonProperty("model") String model, + @JsonProperty("error_code") String errorCode, + @JsonProperty("error_msg") String errorNsg, + @JsonProperty("usage") Usage usage) { + // @formatter:on + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/common/SolarConstants.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/common/SolarConstants.java new file mode 100644 index 00000000000..c66d996fcaa --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/api/common/SolarConstants.java @@ -0,0 +1,19 @@ +package org.springframework.ai.solar.api.common; + +import org.springframework.ai.observation.conventions.AiProvider; + +/** + * Common value constants for Solar api. + * + * @author Seunghyeon Ji + */ +public class SolarConstants { + + public static final String DEFAULT_BASE_URL = "https://api.upstage.ai"; + + public static final String PROVIDER_NAME = AiProvider.SOLAR.value(); + + private SolarConstants() { + } + +} diff --git a/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java new file mode 100644 index 00000000000..8d2e440a56d --- /dev/null +++ b/models/spring-ai-solar/src/main/java/org/springframework/ai/solar/metadata/SolarUsage.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.solar.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.util.Assert; + +/** + * {@link Usage} implementation for {@literal Solar}. + * + * @author Seunghyeon Ji + */ +public class SolarUsage implements Usage { + + private final SolarApi.Usage usage; + + protected SolarUsage(SolarApi.Usage usage) { + Assert.notNull(usage, "Solar Usage must not be null"); + this.usage = usage; + } + + public static SolarUsage from(SolarApi.Usage usage) { + return new SolarUsage(usage); + } + + protected SolarApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return 0L; + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..8ddbd00cf52 --- /dev/null +++ b/models/spring-ai-solar/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.solar.aot.SolarRuntimeHints diff --git a/models/spring-ai-solar/src/test/java/ChatCompletionRequestTests.java b/models/spring-ai-solar/src/test/java/ChatCompletionRequestTests.java new file mode 100644 index 00000000000..243dd744c7f --- /dev/null +++ b/models/spring-ai-solar/src/test/java/ChatCompletionRequestTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.solar.SolarChatModel; +import org.springframework.ai.solar.SolarChatOptions; +import org.springframework.ai.solar.api.SolarApi; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Seunghyeon Ji + */ +public class ChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + var client = new SolarChatModel(new SolarApi("TEST", "TEST"), + SolarChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6); + + request = client.createRequest(new Prompt("Test message content", + SolarChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9); + } + +} diff --git a/models/spring-ai-solar/src/test/java/api/SolarApiIT.java b/models/spring-ai-solar/src/test/java/api/SolarApiIT.java new file mode 100644 index 00000000000..ab314aae4d7 --- /dev/null +++ b/models/spring-ai-solar/src/test/java/api/SolarApiIT.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; +import java.util.Objects; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.ResourceUtils; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.http.ResponseEntity; +import org.stringtemplate.v4.ST; + +import reactor.core.publisher.Flux; + +/** + * @author Seunghyeon Ji + */ +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarApiIT { + + SolarApi solarApi = new SolarApi(System.getenv("SOLAR_API_KEY")); + + @Test + void chatCompletionEntity() { + SolarApi.ChatCompletionMessage chatCompletionMessage = new SolarApi.ChatCompletionMessage("Hello world", + SolarApi.ChatCompletionMessage.Role.USER); + ResponseEntity response = this.solarApi.chatCompletionEntity( + new SolarApi.ChatCompletionRequest(List.of(chatCompletionMessage), SolarApi.DEFAULT_CHAT_MODEL, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void embeddings() { + ResponseEntity response = this.solarApi + .embeddings(new SolarApi.EmbeddingRequest("Hello world")); + + assertThat(response).isNotNull(); + assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1); + assertThat(response.getBody().data().get(0).embedding()).hasSize(4096); + } + +} diff --git a/models/spring-ai-solar/src/test/java/api/SolarRetryTests.java b/models/spring-ai-solar/src/test/java/api/SolarRetryTests.java new file mode 100644 index 00000000000..88465c468d7 --- /dev/null +++ b/models/spring-ai-solar/src/test/java/api/SolarRetryTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.ai.solar.SolarChatModel; +import org.springframework.ai.solar.SolarChatOptions; +import org.springframework.ai.solar.SolarEmbeddingModel; +import org.springframework.ai.solar.SolarEmbeddingOptions; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; + +/** + * @author Seunghyeon Ji + */ +@ExtendWith(MockitoExtension.class) +public class SolarRetryTests { + + private TestRetryListener retryListener; + + private @Mock SolarApi solarApi; + + private SolarChatModel chatClient; + + private SolarEmbeddingModel embeddingClient; + + @BeforeEach + public void beforeEach() { + RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); + + this.chatClient = new SolarChatModel(this.solarApi, SolarChatOptions.builder().build(), retryTemplate); + this.embeddingClient = new SolarEmbeddingModel(this.solarApi, MetadataMode.EMBED, + SolarEmbeddingOptions.builder().build(), retryTemplate); + } + + @Test + public void solarChatTransientError() { + List choices = List.of(new SolarApi.ChatCompletion.Choice(0, + new SolarApi.ChatCompletion.Message("assistant", "Response"), null, "STOP")); + + SolarApi.ChatCompletion expectedChatCompletion = new SolarApi.ChatCompletion("id", "chat.completion", 666L, + SolarApi.DEFAULT_CHAT_MODEL, choices, new SolarApi.Usage(10, 10, 10), null); + + given(this.solarApi.chatCompletionEntity(isA(SolarApi.ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = this.chatClient.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void solarChatNonTransientError() { + given(this.solarApi.chatCompletionEntity(isA(SolarApi.ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.chatClient.call(new Prompt("text"))); + } + + @Test + public void solarEmbeddingTransientError() { + SolarApi.Embedding embedding = new SolarApi.Embedding(1, new float[] { 9.9f, 8.8f }); + SolarApi.EmbeddingList expectedEmbeddings = new SolarApi.EmbeddingList("embedding_list", List.of(embedding), + "model", null, null, new SolarApi.Usage(10, 10, 10)); + + given(this.solarApi.embeddings(isA(SolarApi.EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = this.embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void solarEmbeddingNonTransientError() { + given(this.solarApi.embeddings(isA(SolarApi.EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> this.embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/models/spring-ai-solar/src/test/java/chat/SolarChatModelIT.java b/models/spring-ai-solar/src/test/java/chat/SolarChatModelIT.java new file mode 100644 index 00000000000..b1c7c894815 --- /dev/null +++ b/models/spring-ai-solar/src/test/java/chat/SolarChatModelIT.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package chat; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +/** + * @author Seunghyeon Ji + */ +@SpringBootTest(classes = SolarTestConfiguration.class) +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarChatModelIT { + + @Autowired + protected ChatModel chatModel; + + @Autowired + protected StreamingChatModel streamingChatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about three famous pirates from the Golden Age of Piracy in english, focusing on their original nicknames and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); + } + +} diff --git a/models/spring-ai-solar/src/test/java/chat/SolarChatModelObservationIT.java b/models/spring-ai-solar/src/test/java/chat/SolarChatModelObservationIT.java new file mode 100644 index 00000000000..12550b73086 --- /dev/null +++ b/models/spring-ai-solar/src/test/java/chat/SolarChatModelObservationIT.java @@ -0,0 +1,140 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package chat; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.*; + +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.solar.SolarChatModel; +import org.springframework.ai.solar.SolarChatOptions; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * Integration tests for observation instrumentation in {@link SolarChatModel}. + * + * @author Seunghyeon Ji + */ +@SpringBootTest(classes = SolarChatModelObservationIT.Config.class) +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarChatModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + SolarChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void observationForChatOperation() { + var options = SolarChatOptions.builder() + .withModel(SolarApi.DEFAULT_CHAT_MODEL) + .withFrequencyPenalty(0.0) + .withMaxTokens(2048) + .withPresencePenalty(0.0) + .withStop(List.of("this-is-the-end")) + .withTemperature(0.7) + .withTopP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + private void validate(ChatResponseMetadata responseMetadata) { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("chat " + SolarApi.DEFAULT_CHAT_MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.SOLAR.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), SolarApi.DEFAULT_CHAT_MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), + "[\"this-is-the-end\"]") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_TOP_K.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public SolarApi solarApi() { + return new SolarApi(System.getenv("SOLAR_API_KEY")); + } + + @Bean + public SolarChatModel solarChatModel(SolarApi solarApi, TestObservationRegistry observationRegistry) { + return new SolarChatModel(solarApi, SolarChatOptions.builder().build(), RetryTemplate.defaultInstance(), + observationRegistry); + } + + } + +} diff --git a/models/spring-ai-solar/src/test/java/chat/SolarTestConfiguration.java b/models/spring-ai-solar/src/test/java/chat/SolarTestConfiguration.java new file mode 100644 index 00000000000..fb1e68cd9bd --- /dev/null +++ b/models/spring-ai-solar/src/test/java/chat/SolarTestConfiguration.java @@ -0,0 +1,41 @@ +package chat; + +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.solar.SolarChatModel; +import org.springframework.ai.solar.SolarEmbeddingModel; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Seunghyeon Ji + */ +@SpringBootConfiguration +public class SolarTestConfiguration { + + @Bean + public SolarApi solarApi() { + return new SolarApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("SOLAR_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name SOLAR_API_KEY"); + } + return apiKey; + } + + @Bean + public SolarChatModel solarChatModel(SolarApi api) { + return new SolarChatModel(api); + } + + @Bean + public EmbeddingModel solarEmbeddingModel(SolarApi api) { + return new SolarEmbeddingModel(api); + } + +} diff --git a/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingIT.java b/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingIT.java new file mode 100644 index 00000000000..a249e0bb87b --- /dev/null +++ b/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingIT.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import chat.SolarTestConfiguration; + +/** + * @author Seunghyeon Ji + */ +@SpringBootTest(classes = SolarTestConfiguration.class) +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +class SolarEmbeddingIT { + + @Autowired + private EmbeddingModel embeddingModel; + + @Test + void defaultEmbedding() { + Assertions.assertThat(this.embeddingModel).isNotNull(); + + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World")); + + assertThat(embeddingResponse.getResults()).hasSize(1); + + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(4096); + + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(4096); + } + + @Test + void batchEmbedding() { + Assertions.assertThat(this.embeddingModel).isNotNull(); + + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); + + assertThat(embeddingResponse.getResults()).hasSize(2); + + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(4096); + + assertThat(embeddingResponse.getResults().get(1)).isNotNull(); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(4096); + + Assertions.assertThat(this.embeddingModel.dimensions()).isEqualTo(4096); + } + +} diff --git a/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingModelObservationIT.java b/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingModelObservationIT.java new file mode 100644 index 00000000000..9958b39d29f --- /dev/null +++ b/models/spring-ai-solar/src/test/java/embedding/SolarEmbeddingModelObservationIT.java @@ -0,0 +1,113 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.*; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.solar.SolarEmbeddingModel; +import org.springframework.ai.solar.SolarEmbeddingOptions; +import org.springframework.ai.solar.api.SolarApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * Integration tests for observation instrumentation in {@link SolarEmbeddingModel}. + * + * @author Seunghyeon Ji + */ +@SpringBootTest(classes = SolarEmbeddingModelObservationIT.Config.class) +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "SOLAR_API_KEY", matches = ".+") }) +public class SolarEmbeddingModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + SolarEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + var options = SolarEmbeddingOptions.builder().withModel(SolarApi.DEFAULT_EMBEDDING_MODEL).build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + SolarApi.DEFAULT_EMBEDDING_MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.SOLAR.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), + SolarApi.DEFAULT_EMBEDDING_MODEL) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public SolarApi solarApi() { + return new SolarApi(System.getenv("SOLAR_API_KEY")); + } + + @Bean + public SolarEmbeddingModel solarEmbeddingModel(SolarApi solarApi, TestObservationRegistry observationRegistry) { + return new SolarEmbeddingModel(solarApi, MetadataMode.EMBED, SolarEmbeddingOptions.builder().build(), + RetryTemplate.defaultInstance(), observationRegistry); + } + + } + +} diff --git a/models/spring-ai-solar/src/test/resources/prompts/system-message.st b/models/spring-ai-solar/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..d946c71b6d6 --- /dev/null +++ b/models/spring-ai-solar/src/test/resources/prompts/system-message.st @@ -0,0 +1,3 @@ +You are an AI assistant that helps people find information. +Your name is {name}. +You should reply to the user's request with your name and also in the style of a {voice}. diff --git a/pom.xml b/pom.xml index 6af6e0d0c94..b3602dcd5b3 100644 --- a/pom.xml +++ b/pom.xml @@ -96,6 +96,7 @@ models/spring-ai-openai models/spring-ai-postgresml models/spring-ai-qianfan + models/spring-ai-solar models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-embedding @@ -404,7 +405,7 @@ plain - + false diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index d6b14c5a70b..7a5c3279d2b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -44,7 +44,8 @@ public enum AiProvider { SPRING_AI("spring_ai"), VERTEX_AI("vertex_ai"), BEDROCK_CONVERSE("bedrock_converse"), - ONNX("onnx"); + ONNX("onnx"), + SOLAR("solar"); private final String value;