diff --git a/models/spring-ai-azure-openai/README_AZURE_OPENAI.md b/models/spring-ai-azure-openai/README_AZURE_OPENAI.md
index 888e6c2144a..51ec3373197 100644
--- a/models/spring-ai-azure-openai/README_AZURE_OPENAI.md
+++ b/models/spring-ai-azure-openai/README_AZURE_OPENAI.md
@@ -1,78 +1,9 @@
-# 1. Azure OpenAI
+# Azure OpenAI
Provides Azure OpenAI Chat and Embedding clients.
Leverages the native [OpenAIClient](https://learn.microsoft.com/en-us/java/api/overview/azure/ai-openai-readme?view=azure-java-preview#streaming-chat-completions) to interact with the [Azure AI Studio models and deployment](https://oai.azure.com/).
-## 1.1 Prerequisites
+Find additional information:
-1. Azure Subscription: You will need an [Azure subscription](https://azure.microsoft.com/en-us/free/) to use any Azure service.
-2. Azure AI, Azure OpenAI Service: Create [Azure OpenAI](https://portal.azure.com/#create/Microsoft.CognitiveServicesOpenAI).
-Once the service is created, obtain the endpoint and apiKey from the `Keys and Endpoint` section under `Resource Management`.
-3. Use the [Azure Ai Studio](https://oai.azure.com/portal) to deploy the models you are going to use.
-
-## 1.2 AzureOpenAiChatClient
-
-[AzureOpenAiChatClient](./src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java) implements the Spring-Ai `ChatClient` and `StreamingChatClient` on top of the `OpenAIClient`.
-
-[AzureOpenAiEmbeddingClient](./src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java) implements the Spring-Ai `EmbeddingClient` on top of the `OpenAIClient`.
-
-
-You can configure the AzureOpenAiChatClient and AzureOpenAiEmbeddingClientlike this:
-
-```java
-@Bean
-public OpenAIClient openAIClient() {
- return new OpenAIClientBuilder()
- .credential(new AzureKeyCredential({YOUR_AZURE_OPENAI_API_KEY}))
- .endpoint({YOUR_AZURE_OPENAI_ENDPOINT})
- .buildClient();
-}
-
-@Bean
-public AzureOpenAiChatClient cohereChatClient(OpenAIClient openAIClient) {
- return new AzureOpenAiChatClient(openAIClient)
- .withModel("gpt-35-turbo")
- .withMaxTokens(200)
- .withTemperature(0.8);
-}
-
-@Bean
-public AzureOpenAiEmbeddingClient cohereEmbeddingClient(OpenAIClient openAIClient) {
- return new AzureOpenAiEmbeddingClient(openAIClient, "text-embedding-ada-002-v1");
-}
-```
-
-## 1.3 Azure OpenAi Auto-Configuration and Spring Boot Starter
-
-You can leverage the `spring-ai-azure-openai-spring-boot-starter` Boot starter.
-For this add the following dependency:
-
-```xml
-
- spring-ai-azure-openai-spring-boot-starter
- org.springframework.ai
- 0.8.0-SNAPSHOT
-
-```
-
-Use the `AzureOpenAiConnectionProperties` to configure the Azure OpenAI access:
-
-| Property | Description | Default |
-| ------------- | ------------- | ------------- |
-| spring.ai.azure.openai.apiKey | Azure AI Open AI credentials api key. | From the Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` |
-| spring.ai.azure.openai.endpoint | Azure AI Open AI endpoint. | From the Azure AI OpenAI `Keys and Endpoint` section under `Resource Management` |
-
-Use the `AzureOpenAiChatProperties` to configure the Chat client:
-
-| Property | Description | Default |
-| ------------- | ------------- | ------------- |
-| spring.ai.azure.openai.chat.model | The model id to use. | gpt-35-turbo |
-| spring.ai.azure.openai.chat.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.7 |
-| spring.ai.azure.openai.chat.topP | An alternative to sampling with temperature called nucleus sampling. | |
-| spring.ai.azure.openai.chat.maxTokens | The maximum number of tokens to generate. | |
-
-Use the `AzureOpenAiEmbeddingProperties` to configure the Embedding client:
-
-| Property | Description | Default |
-| ------------- | ------------- | ------------- |
-| spring.ai.azure.openai.embedding.model | The model id to use for embedding | text-embedding-ada-002 |
+- [Azure OpenAI Chat Client](https://docs.spring.io/spring-ai/reference/api/clients/azure-openai-chat.html)
+- [Azure OpenAI Embeddings Client](https://docs.spring.io/spring-ai/reference/api/embeddings/azure-openai-embeddings.html)
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java
index 1dc941c7759..2699da57b9b 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023 the original author or authors.
+ * Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,23 +27,24 @@
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
-import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import reactor.core.publisher.Flux;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
+import org.springframework.ai.chat.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
/**
@@ -59,104 +60,39 @@
*/
public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {
- /**
- * The sampling temperature to use that controls the apparent creativity of generated
- * completions. Higher values will make output more random while lower values will
- * make results more focused and deterministic. It is not recommended to modify
- * temperature and top_p for the same completions request as the interaction of these
- * two settings is difficult to predict.
- */
- private Double temperature = 0.7;
+ private static final String DEFAULT_MODEL = "gpt-35-turbo";
- /**
- * An alternative to sampling with temperature called nucleus sampling. This value
- * causes the model to consider the results of tokens with the provided probability
- * mass. As an example, a value of 0.15 will cause only the tokens comprising the top
- * 15% of probability mass to be considered. It is not recommended to modify
- * temperature and top_p for the same completions request as the interaction of these
- * two settings is difficult to predict.
- */
- private Double topP;
+ private static final Float DEFAULT_TEMPERATURE = 0.7f;
+
+ private final Logger logger = LoggerFactory.getLogger(getClass());
/**
- * Creates an instance of ChatCompletionsOptions class.
+ * The configuration information for a chat completions request.
*/
- private String model = "gpt-35-turbo";
+ private AzureOpenAiChatOptions defaultOptions;
/**
- * The maximum number of tokens to generate.
+ * The {@link OpenAIClient} used to interact with the Azure OpenAI service.
*/
- private Integer maxTokens;
-
- private final Logger logger = LoggerFactory.getLogger(getClass());
-
private final OpenAIClient openAIClient;
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
this.openAIClient = microsoftOpenAiClient;
+ this.defaultOptions = AzureOpenAiChatOptions.builder()
+ .withModel(DEFAULT_MODEL)
+ .withTemperature(DEFAULT_TEMPERATURE)
+ .build();
}
- public String getModel() {
- return this.model;
- }
-
- public AzureOpenAiChatClient withModel(String model) {
- this.model = model;
- return this;
- }
-
- public Double getTemperature() {
- return this.temperature;
- }
-
- public AzureOpenAiChatClient withTemperature(Double temperature) {
- this.temperature = temperature;
- return this;
- }
-
- public Double getTopP() {
- return topP;
- }
-
- public AzureOpenAiChatClient withTopP(Double topP) {
- this.topP = topP;
+ public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
+ Assert.notNull(defaultOptions, "DefaultOptions must not be null");
+ this.defaultOptions = defaultOptions;
return this;
}
- public Integer getMaxTokens() {
- return maxTokens;
- }
-
- public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) {
- this.maxTokens = maxTokens;
- return this;
- }
-
- @Override
- public String call(String text) {
-
- ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text);
-
- ChatCompletionsOptions options = new ChatCompletionsOptions(List.of(azureChatMessage));
- options.setTemperature(this.getTemperature());
- options.setModel(this.getModel());
-
- logger.trace("Azure Chat Message: {}", azureChatMessage);
-
- ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
- logger.trace("Azure ChatCompletions: {}", chatCompletions);
-
- StringBuilder stringBuilder = new StringBuilder();
-
- for (ChatChoice choice : chatCompletions.getChoices()) {
- ChatResponseMessage message = choice.getMessage();
- if (message != null && message.getContent() != null) {
- stringBuilder.append(message.getContent());
- }
- }
-
- return stringBuilder.toString();
+ public AzureOpenAiChatOptions getDefaultOptions() {
+ return defaultOptions;
}
@Override
@@ -167,7 +103,7 @@ public ChatResponse call(Prompt prompt) {
logger.trace("Azure ChatCompletionsOptions: {}", options);
- ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
+ ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
logger.trace("Azure ChatCompletions: {}", chatCompletions);
@@ -178,6 +114,7 @@ public ChatResponse call(Prompt prompt) {
.toList();
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
+
return new ChatResponse(generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
}
@@ -189,12 +126,11 @@ public Flux stream(Prompt prompt) {
options.setStream(true);
IterableStream chatCompletionsStream = this.openAIClient
- .getChatCompletionsStream(this.getModel(), options);
+ .getChatCompletionsStream(options.getModel(), options);
return Flux.fromStream(chatCompletionsStream.stream()
// Note: the first chat completions can be ignored when using Azure OpenAI
- // service which is a
- // known service bug.
+ // service which is a known service bug.
.skip(1)
.map(ChatCompletions::getChoices)
.flatMap(List::stream)
@@ -205,7 +141,10 @@ public Flux stream(Prompt prompt) {
}));
}
- private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
+ /**
+ * Test access.
+ */
+ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
List azureMessages = prompt.getInstructions()
.stream()
@@ -214,10 +153,29 @@ private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);
- options.setTemperature(this.getTemperature());
- options.setModel(this.getModel());
- options.setTopP(this.getTopP());
- options.setMaxTokens(this.getMaxTokens());
+ if (this.defaultOptions != null) {
+ // JSON merge doesn't work due to Azure OpenAI service bug:
+ // https://github.com/Azure/azure-sdk-for-java/issues/38183
+ // options = ModelOptionsUtils.merge(options, this.defaultOptions,
+ // ChatCompletionsOptions.class);
+ options = merge(options, this.defaultOptions);
+ }
+
+ if (prompt.getOptions() != null) {
+ if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
+ AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
+ ChatOptions.class, AzureOpenAiChatOptions.class);
+ // JSON merge doesn't work due to Azure OpenAI service bug:
+ // https://github.com/Azure/azure-sdk-for-java/issues/38183
+ // options = ModelOptionsUtils.merge(runtimeOptions, options,
+ // ChatCompletionsOptions.class);
+ options = merge(updatedRuntimeOptions, options);
+ }
+ else {
+ throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
+ + prompt.getOptions().getClass().getSimpleName());
+ }
+ }
return options;
}
@@ -256,4 +214,121 @@ private List nullSafeList(List list) {
return list != null ? list : Collections.emptyList();
}
+ // JSON merge doesn't work due to Azure OpenAI service bug:
+ // https://github.com/Azure/azure-sdk-for-java/issues/38183
+ private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {
+
+ if (springAiOptions == null) {
+ return azureOptions;
+ }
+
+ ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+ mergedAzureOptions.setStream(azureOptions.isStream());
+
+ mergedAzureOptions.setMaxTokens(azureOptions.getMaxTokens());
+ if (mergedAzureOptions.getMaxTokens() == null) {
+ mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+ }
+
+ mergedAzureOptions.setLogitBias(azureOptions.getLogitBias());
+ if (mergedAzureOptions.getLogitBias() == null) {
+ mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+ }
+
+ mergedAzureOptions.setStop(azureOptions.getStop());
+ if (mergedAzureOptions.getStop() == null) {
+ mergedAzureOptions.setStop(springAiOptions.getStop());
+ }
+
+ mergedAzureOptions.setTemperature(azureOptions.getTemperature());
+ if (mergedAzureOptions.getTemperature() == null && springAiOptions.getTemperature() != null) {
+ mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+ }
+
+ mergedAzureOptions.setTopP(azureOptions.getTopP());
+ if (mergedAzureOptions.getTopP() == null && springAiOptions.getTopP() != null) {
+ mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+ }
+
+ mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
+ if (mergedAzureOptions.getFrequencyPenalty() == null && springAiOptions.getFrequencyPenalty() != null) {
+ mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+ }
+
+ mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty());
+ if (mergedAzureOptions.getPresencePenalty() == null && springAiOptions.getPresencePenalty() != null) {
+ mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+ }
+
+ mergedAzureOptions.setN(azureOptions.getN());
+ if (mergedAzureOptions.getN() == null) {
+ mergedAzureOptions.setN(springAiOptions.getN());
+ }
+
+ mergedAzureOptions.setUser(azureOptions.getUser());
+ if (mergedAzureOptions.getUser() == null) {
+ mergedAzureOptions.setUser(springAiOptions.getUser());
+ }
+
+ mergedAzureOptions.setModel(azureOptions.getModel());
+ if (mergedAzureOptions.getModel() == null) {
+ mergedAzureOptions.setModel(springAiOptions.getModel());
+ }
+
+ return mergedAzureOptions;
+ }
+
+ // JSON merge doesn't work due to Azure OpenAI service bug:
+ // https://github.com/Azure/azure-sdk-for-java/issues/38183
+ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
+ if (springAiOptions == null) {
+ return azureOptions;
+ }
+
+ ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+ mergedAzureOptions.setStream(azureOptions.isStream());
+
+ if (springAiOptions.getMaxTokens() != null) {
+ mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+ }
+
+ if (springAiOptions.getLogitBias() != null) {
+ mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+ }
+
+ if (springAiOptions.getStop() != null) {
+ mergedAzureOptions.setStop(springAiOptions.getStop());
+ }
+
+ if (springAiOptions.getTemperature() != null) {
+ mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+ }
+
+ if (springAiOptions.getTopP() != null) {
+ mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+ }
+
+ if (springAiOptions.getFrequencyPenalty() != null) {
+ mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+ }
+
+ if (springAiOptions.getPresencePenalty() != null) {
+ mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+ }
+
+ if (springAiOptions.getN() != null) {
+ mergedAzureOptions.setN(springAiOptions.getN());
+ }
+
+ if (springAiOptions.getUser() != null) {
+ mergedAzureOptions.setUser(springAiOptions.getUser());
+ }
+
+ if (springAiOptions.getModel() != null) {
+ mergedAzureOptions.setModel(springAiOptions.getModel());
+ }
+
+ return mergedAzureOptions;
+ }
+
}
\ No newline at end of file
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java
new file mode 100644
index 00000000000..4b1878def89
--- /dev/null
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.azure.openai;
+
+import java.util.List;
+import java.util.Map;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.chat.ChatOptions;
+
+/**
+ * The configuration information for a chat completions request. Completions support a
+ * wide variety of tasks and generate text that continues from or "completes" provided
+ * prompt data.
+ *
+ * @author Christian Tzolov
+ */
+@JsonInclude(Include.NON_NULL)
+public class AzureOpenAiChatOptions implements ChatOptions {
+
+ /**
+ * The maximum number of tokens to generate.
+ */
+ @JsonProperty(value = "max_tokens")
+ private Integer maxTokens;
+
+ /**
+ * The sampling temperature to use that controls the apparent creativity of generated
+ * completions. Higher values will make output more random while lower values will
+ * make results more focused and deterministic. It is not recommended to modify
+ * temperature and top_p for the same completions request as the interaction of these
+ * two settings is difficult to predict.
+ */
+ @JsonProperty(value = "temperature")
+ private Float temperature;
+
+ /**
+ * An alternative to sampling with temperature called nucleus sampling. This value
+ * causes the model to consider the results of tokens with the provided probability
+ * mass. As an example, a value of 0.15 will cause only the tokens comprising the top
+ * 15% of probability mass to be considered. It is not recommended to modify
+ * temperature and top_p for the same completions request as the interaction of these
+ * two settings is difficult to predict.
+ */
+ @JsonProperty(value = "top_p")
+ private Float topP;
+
+ /**
+ * A map between GPT token IDs and bias scores that influences the probability of
+ * specific tokens appearing in a completions response. Token IDs are computed via
+ * external tokenizer tools, while bias scores reside in the range of -100 to 100 with
+ * minimum and maximum values corresponding to a full ban or exclusive selection of a
+ * token, respectively. The exact behavior of a given bias score varies by model.
+ */
+ @JsonProperty(value = "logit_bias")
+ private Map logitBias;
+
+ /**
+ * An identifier for the caller or end user of the operation. This may be used for
+ * tracking or rate-limiting purposes.
+ */
+ @JsonProperty(value = "user")
+ private String user;
+
+ /**
+ * The number of chat completions choices that should be generated for a chat
+ * completions response. Because this setting can generate many completions, it may
+ * quickly consume your token quota. Use carefully and ensure reasonable settings for
+ * max_tokens and stop.
+ */
+ @JsonProperty(value = "n")
+ private Integer n;
+
+ /**
+ * A collection of textual sequences that will end completions generation.
+ */
+ @JsonProperty(value = "stop")
+ private List stop;
+
+ /**
+ * A value that influences the probability of generated tokens appearing based on
+ * their existing presence in generated text. Positive values will make tokens less
+ * likely to appear when they already exist and increase the model's likelihood to
+ * output new topics.
+ */
+ @JsonProperty(value = "presence_penalty")
+ private Double presencePenalty;
+
+ /**
+ * A value that influences the probability of generated tokens appearing based on
+ * their cumulative frequency in generated text. Positive values will make tokens less
+ * likely to appear as their frequency increases and decrease the likelihood of the
+ * model repeating the same statements verbatim.
+ */
+ @JsonProperty(value = "frequency_penalty")
+ private Double frequencyPenalty;
+
+ /**
+ * The model name to provide as part of this completions request. Not applicable to
+ * Azure OpenAI, where deployment information should be included in the Azure resource
+ * URI that's connected to.
+ */
+ @JsonProperty(value = "model")
+ private String model;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ protected AzureOpenAiChatOptions options;
+
+ public Builder() {
+ this.options = new AzureOpenAiChatOptions();
+ }
+
+ public Builder(AzureOpenAiChatOptions options) {
+ this.options = options;
+ }
+
+ public Builder withModel(String model) {
+ this.options.model = model;
+ return this;
+ }
+
+ public Builder withFrequencyPenalty(Float frequencyPenalty) {
+ this.options.frequencyPenalty = frequencyPenalty.doubleValue();
+ return this;
+ }
+
+ public Builder withLogitBias(Map logitBias) {
+ this.options.logitBias = logitBias;
+ return this;
+ }
+
+ public Builder withMaxTokens(Integer maxTokens) {
+ this.options.maxTokens = maxTokens;
+ return this;
+ }
+
+ public Builder withN(Integer n) {
+ this.options.n = n;
+ return this;
+ }
+
+ public Builder withPresencePenalty(Float presencePenalty) {
+ this.options.presencePenalty = presencePenalty.doubleValue();
+ return this;
+ }
+
+ public Builder withStop(List stop) {
+ this.options.stop = stop;
+ return this;
+ }
+
+ public Builder withTemperature(Float temperature) {
+ this.options.temperature = temperature;
+ return this;
+ }
+
+ public Builder withTopP(Float topP) {
+ this.options.topP = topP;
+ return this;
+ }
+
+ public Builder withUser(String user) {
+ this.options.user = user;
+ return this;
+ }
+
+ public AzureOpenAiChatOptions build() {
+ return this.options;
+ }
+
+ }
+
+ public Integer getMaxTokens() {
+ return this.maxTokens;
+ }
+
+ public void setMaxTokens(Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public Map getLogitBias() {
+ return this.logitBias;
+ }
+
+ public void setLogitBias(Map logitBias) {
+ this.logitBias = logitBias;
+ }
+
+ public String getUser() {
+ return this.user;
+ }
+
+ public void setUser(String user) {
+ this.user = user;
+ }
+
+ public Integer getN() {
+ return this.n;
+ }
+
+ public void setN(Integer n) {
+ this.n = n;
+ }
+
+ public List getStop() {
+ return this.stop;
+ }
+
+ public void setStop(List stop) {
+ this.stop = stop;
+ }
+
+ public Double getPresencePenalty() {
+ return this.presencePenalty;
+ }
+
+ public void setPresencePenalty(Double presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ public Double getFrequencyPenalty() {
+ return this.frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(Double frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ public String getModel() {
+ return this.model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+ @Override
+ public Float getTemperature() {
+ return this.temperature;
+ }
+
+ @Override
+ public void setTemperature(Float temperature) {
+ this.temperature = temperature;
+ }
+
+ @Override
+ public Float getTopP() {
+ return this.topP;
+ }
+
+ @Override
+ public void setTopP(Float topP) {
+ this.topP = topP;
+ }
+
+ @Override
+ @JsonIgnore
+ public Integer getTopK() {
+ throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
+ }
+
+ @Override
+ @JsonIgnore
+ public void setTopK(Integer topK) {
+ throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
+ }
+
+}
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java
index 6e217f5e5a2..76bfc77f19c 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java
@@ -18,6 +18,7 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
+import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
@@ -26,52 +27,69 @@ public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
private final OpenAIClient azureOpenAiClient;
- private final String model;
+ private AzureOpenAiEmbeddingOptions defaultOptions = AzureOpenAiEmbeddingOptions.builder()
+ .withModel("text-embedding-ada-002")
+ .build();
private final MetadataMode metadataMode;
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {
- this(azureOpenAiClient, "text-embedding-ada-002");
+ this(azureOpenAiClient, MetadataMode.EMBED);
}
- public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model) {
- this(azureOpenAiClient, model, MetadataMode.EMBED);
- }
-
- public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model, MetadataMode metadataMode) {
+ public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
- Assert.notNull(model, "Model must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
this.azureOpenAiClient = azureOpenAiClient;
- this.model = model;
this.metadataMode = metadataMode;
}
@Override
public List embed(Document document) {
logger.debug("Retrieving embeddings");
- Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model,
- new EmbeddingsOptions(List.of(document.getFormattedContent(this.metadataMode))));
- logger.debug("Embeddings retrieved");
- return extractEmbeddingsList(embeddings);
- }
- private List extractEmbeddingsList(Embeddings embeddings) {
- return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).flatMap(List::stream).toList();
+ EmbeddingResponse response = this
+ .call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
+ logger.debug("Embeddings retrieved");
+ return response.getResults().stream().map(embedding -> embedding.getOutput()).flatMap(List::stream).toList();
}
@Override
- public EmbeddingResponse call(EmbeddingRequest request) {
+ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
logger.debug("Retrieving embeddings");
- Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model,
- new EmbeddingsOptions(request.getInstructions()));
+
+ EmbeddingsOptions azureOptions = new EmbeddingsOptions(embeddingRequest.getInstructions());
+ if (this.defaultOptions != null) {
+ azureOptions = ModelOptionsUtils.merge(azureOptions, this.defaultOptions, EmbeddingsOptions.class);
+ }
+ if (embeddingRequest.getOptions() != null) {
+ azureOptions = ModelOptionsUtils.merge(embeddingRequest.getOptions(), azureOptions,
+ EmbeddingsOptions.class);
+ }
+ Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions);
+
logger.debug("Embeddings retrieved");
return generateEmbeddingResponse(embeddings);
}
+ /**
+ * Test access
+ */
+ EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
+ var azureOptions = new EmbeddingsOptions(embeddingRequest.getInstructions());
+ if (this.defaultOptions != null) {
+ azureOptions = ModelOptionsUtils.merge(azureOptions, this.defaultOptions, EmbeddingsOptions.class);
+ }
+ if (embeddingRequest.getOptions() != null) {
+ azureOptions = ModelOptionsUtils.merge(embeddingRequest.getOptions(), azureOptions,
+ EmbeddingsOptions.class);
+ }
+ return azureOptions;
+ }
+
private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
List data = generateEmbeddingList(embeddings.getData());
- EmbeddingResponseMetadata metadata = generateMetadata(this.model, embeddings.getUsage());
+ EmbeddingResponseMetadata metadata = generateMetadata(embeddings.getUsage());
return new EmbeddingResponse(data, metadata);
}
@@ -86,12 +104,26 @@ private List generateEmbeddingList(List nativeData) {
return data;
}
- private EmbeddingResponseMetadata generateMetadata(String model, EmbeddingsUsage embeddingsUsage) {
+ private EmbeddingResponseMetadata generateMetadata(EmbeddingsUsage embeddingsUsage) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
- metadata.put("model", model);
+ // metadata.put("model", model);
metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens());
metadata.put("total-tokens", embeddingsUsage.getTotalTokens());
return metadata;
}
+ public AzureOpenAiEmbeddingOptions getDefaultOptions() {
+ return this.defaultOptions;
+ }
+
+ public void setDefaultOptions(AzureOpenAiEmbeddingOptions defaultOptions) {
+ Assert.notNull(defaultOptions, "Default options must not be null");
+ this.defaultOptions = defaultOptions;
+ }
+
+ public AzureOpenAiEmbeddingClient withDefaultOptions(AzureOpenAiEmbeddingOptions options) {
+ this.defaultOptions = options;
+ return this;
+ }
+
}
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java
new file mode 100644
index 00000000000..3d89c40ae01
--- /dev/null
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.azure.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import org.springframework.ai.embedding.EmbeddingOptions;
+
+/**
+ * The configuration information for the embedding requests.
+ *
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions {
+
+ /**
+ * An identifier for the caller or end user of the operation. This may be used for
+ * tracking or rate-limiting purposes.
+ */
+ @JsonProperty(value = "user")
+ private String user;
+
+ /**
+ * The model name to provide as part of this embeddings request. Not applicable to
+ * Azure OpenAI, where deployment information should be included in the Azure resource
+ * URI that's connected to.
+ */
+ @JsonProperty(value = "model")
+ private String model;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ private final AzureOpenAiEmbeddingOptions options = new AzureOpenAiEmbeddingOptions();
+
+ public Builder withUser(String user) {
+ this.options.setUser(user);
+ return this;
+ }
+
+ public Builder withModel(String model) {
+ this.options.setModel(model);
+ return this;
+ }
+
+ public AzureOpenAiEmbeddingOptions build() {
+ return this.options;
+ }
+
+ }
+
+ public String getUser() {
+ return this.user;
+ }
+
+ public void setUser(String user) {
+ this.user = user;
+ }
+
+ public String getModel() {
+ return this.model;
+ }
+
+ public void setModel(String model) {
+ this.model = model;
+ }
+
+}
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java
new file mode 100644
index 00000000000..0d8a5509b9e
--- /dev/null
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.azure.openai;
+
+import com.azure.ai.openai.OpenAIClient;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import org.springframework.ai.chat.prompt.Prompt;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ */
+public class AzureChatCompletionsOptionsTests {
+
+ @Test
+ public void createRequestWithChatOptions() {
+
+ OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
+ var client = new AzureOpenAiChatClient(mockClient).withDefaultOptions(
+ AzureOpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
+
+ var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));
+
+ assertThat(requestOptions.getMessages()).hasSize(1);
+
+ assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
+ assertThat(requestOptions.getTemperature()).isEqualTo(66.6f);
+
+ requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content",
+ AzureOpenAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()));
+
+ assertThat(requestOptions.getMessages()).hasSize(1);
+
+ assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL");
+ assertThat(requestOptions.getTemperature()).isEqualTo(99.9f);
+ }
+
+}
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java
new file mode 100644
index 00000000000..0aaa0abcd4c
--- /dev/null
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureEmbeddingsOptionsTests.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.azure.openai;
+
+import java.util.List;
+
+import com.azure.ai.openai.OpenAIClient;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import org.springframework.ai.embedding.EmbeddingRequest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ * @since 0.8.0
+ */
+public class AzureEmbeddingsOptionsTests {
+
+ @Test
+ public void createRequestWithChatOptions() {
+
+ OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
+ var client = new AzureOpenAiEmbeddingClient(mockClient).withDefaultOptions(
+ AzureOpenAiEmbeddingOptions.builder().withModel("DEFAULT_MODEL").withUser("USER_TEST").build());
+
+ var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"), null));
+
+ assertThat(requestOptions.getInput()).hasSize(1);
+
+ assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
+ assertThat(requestOptions.getUser()).isEqualTo("USER_TEST");
+
+ requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"),
+ AzureOpenAiEmbeddingOptions.builder().withModel("PROMPT_MODEL").withUser("PROMPT_USER").build()));
+
+ assertThat(requestOptions.getInput()).hasSize(1);
+
+ assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL");
+ assertThat(requestOptions.getUser()).isEqualTo("PROMPT_USER");
+ }
+
+}
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java
index a65df0c955d..0de353e6010 100644
--- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java
@@ -180,7 +180,9 @@ public OpenAIClient openAIClient() {
@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
- return new AzureOpenAiChatClient(openAIClient).withModel("gpt-35-turbo").withMaxTokens(200);
+ return new AzureOpenAiChatClient(openAIClient).withDefaultOptions(
+ AzureOpenAiChatOptions.builder().withModel("gpt-35-turbo").withMaxTokens(200).build());
+
}
}
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java
index f23e2e6abbb..00f301f72c6 100644
--- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java
@@ -60,7 +60,7 @@ public OpenAIClient openAIClient() {
@Bean
public AzureOpenAiEmbeddingClient azureEmbeddingClient(OpenAIClient openAIClient) {
- return new AzureOpenAiEmbeddingClient(openAIClient, "text-embedding-ada-002");
+ return new AzureOpenAiEmbeddingClient(openAIClient);
}
}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
index ec800ae424e..8fe0e8feba3 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java
@@ -130,7 +130,11 @@ public static Map objectToMap(Object source) {
try {
String json = OBJECT_MAPPER.writeValueAsString(source);
return OBJECT_MAPPER.readValue(json, new TypeReference