
Commit 5405b63

Add Azure OpenAI Chat and Embedding Options
- Add AzureOpenAiChatOptions.
- Add a default options field to AzureOpenAiChatClient.
- Implement merging of runtime (e.g. prompt) and default options on call.
- Add an options field to AzureOpenAiChatProperties.
- Add AzureOpenAiEmbeddingOptions.
- Add a default options field to AzureOpenAiEmbeddingClient.
- Implement runtime and default option merging on the embedding request.
- Add an options field to AzureOpenAiEmbeddingProperties.
- Add unit and integration tests.
- Split azure-openai.adoc into ./clients/azure-openai-chat.adoc and ./embeddings/azure-openai-embeddings.adoc.
- Provide a detailed explanation of how to use the chat and embedding clients manually or via the auto-configuration.
1 parent 9b79359 commit 5405b63
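
In practice, the new options model means that defaults are configured once on the chat client and can be overridden per request. Below is a minimal usage sketch, not code from this commit: the OpenAIClientBuilder setup, the placeholder endpoint and key, and the Prompt constructor taking per-request options are assumptions, while the builder calls mirror the ones visible in the diff below.

    import com.azure.ai.openai.OpenAIClient;
    import com.azure.ai.openai.OpenAIClientBuilder;
    import com.azure.core.credential.AzureKeyCredential;
    import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
    import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
    import org.springframework.ai.chat.ChatResponse;
    import org.springframework.ai.chat.prompt.Prompt;

    public class AzureOpenAiChatOptionsExample {

        public static void main(String[] args) {
            // Plain Azure SDK client (placeholder endpoint and key).
            OpenAIClient azureClient = new OpenAIClientBuilder()
                .endpoint("https://my-resource.openai.azure.com")
                .credential(new AzureKeyCredential("my-api-key"))
                .buildClient();

            // Client-wide defaults, replacing the built-in gpt-35-turbo / 0.7 defaults.
            AzureOpenAiChatClient chatClient = new AzureOpenAiChatClient(azureClient)
                .withDefaultOptions(AzureOpenAiChatOptions.builder()
                    .withModel("gpt-35-turbo")
                    .withTemperature(0.7f)
                    .build());

            // Per-request (runtime) options are merged on top of the defaults for this call only.
            ChatResponse response = chatClient.call(
                new Prompt("Tell me a joke",
                    AzureOpenAiChatOptions.builder().withTemperature(0.2f).build()));

            System.out.println(response);
        }
    }

The same default-versus-runtime split is applied to the embedding client, and the defaults can also be bound from configuration properties through the auto-configuration.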

File tree: 17 files changed (+1136 -212 lines)

AzureOpenAiChatClient.java: 164 additions & 92 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright 2023 the original author or authors.
+ * Copyright 2023-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,23 +27,22 @@
 import com.azure.ai.openai.models.ChatRequestMessage;
 import com.azure.ai.openai.models.ChatRequestSystemMessage;
 import com.azure.ai.openai.models.ChatRequestUserMessage;
-import com.azure.ai.openai.models.ChatResponseMessage;
 import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
 import com.azure.core.util.IterableStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import reactor.core.publisher.Flux;
 
 import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
 import org.springframework.ai.chat.ChatClient;
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.Generation;
 import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.PromptMetadata;
 import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.messages.Message;
 import org.springframework.util.Assert;
 
 /**
@@ -59,104 +58,39 @@
  */
 public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {
 
-    /**
-     * The sampling temperature to use that controls the apparent creativity of generated
-     * completions. Higher values will make output more random while lower values will
-     * make results more focused and deterministic. It is not recommended to modify
-     * temperature and top_p for the same completions request as the interaction of these
-     * two settings is difficult to predict.
-     */
-    private Double temperature = 0.7;
+    private static final String DEFAULT_MODEL = "gpt-35-turbo";
 
-    /**
-     * An alternative to sampling with temperature called nucleus sampling. This value
-     * causes the model to consider the results of tokens with the provided probability
-     * mass. As an example, a value of 0.15 will cause only the tokens comprising the top
-     * 15% of probability mass to be considered. It is not recommended to modify
-     * temperature and top_p for the same completions request as the interaction of these
-     * two settings is difficult to predict.
-     */
-    private Double topP;
+    private static final Float DEFAULT_TEMPERATURE = 0.7f;
+
+    private final Logger logger = LoggerFactory.getLogger(getClass());
 
     /**
-     * Creates an instance of ChatCompletionsOptions class.
+     * The configuration information for a chat completions request.
      */
-    private String model = "gpt-35-turbo";
+    private AzureOpenAiChatOptions defaultOptions;
 
     /**
-     * The maximum number of tokens to generate.
+     * The {@link OpenAIClient} used to interact with the Azure OpenAI service.
      */
-    private Integer maxTokens;
-
-    private final Logger logger = LoggerFactory.getLogger(getClass());
-
     private final OpenAIClient openAIClient;
 
     public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
         Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
         this.openAIClient = microsoftOpenAiClient;
+        this.defaultOptions = AzureOpenAiChatOptions.builder()
+            .withModel(DEFAULT_MODEL)
+            .withTemperature(DEFAULT_TEMPERATURE)
+            .build();
     }
 
-    public String getModel() {
-        return this.model;
-    }
-
-    public AzureOpenAiChatClient withModel(String model) {
-        this.model = model;
-        return this;
-    }
-
-    public Double getTemperature() {
-        return this.temperature;
-    }
-
-    public AzureOpenAiChatClient withTemperature(Double temperature) {
-        this.temperature = temperature;
-        return this;
-    }
-
-    public Double getTopP() {
-        return topP;
-    }
-
-    public AzureOpenAiChatClient withTopP(Double topP) {
-        this.topP = topP;
+    public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
+        Assert.notNull(defaultOptions, "DefaultOptions must not be null");
+        this.defaultOptions = defaultOptions;
         return this;
     }
 
-    public Integer getMaxTokens() {
-        return maxTokens;
-    }
-
-    public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) {
-        this.maxTokens = maxTokens;
-        return this;
-    }
-
-    @Override
-    public String call(String text) {
-
-        ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text);
-
-        ChatCompletionsOptions options = new ChatCompletionsOptions(List.of(azureChatMessage));
-        options.setTemperature(this.getTemperature());
-        options.setModel(this.getModel());
-
-        logger.trace("Azure Chat Message: {}", azureChatMessage);
-
-        ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
-        logger.trace("Azure ChatCompletions: {}", chatCompletions);
-
-        StringBuilder stringBuilder = new StringBuilder();
-
-        for (ChatChoice choice : chatCompletions.getChoices()) {
-            ChatResponseMessage message = choice.getMessage();
-            if (message != null && message.getContent() != null) {
-                stringBuilder.append(message.getContent());
-            }
-        }
-
-        return stringBuilder.toString();
+    public AzureOpenAiChatOptions getDefaultOptions() {
+        return defaultOptions;
     }
 
     @Override
@@ -167,7 +101,7 @@ public ChatResponse call(Prompt prompt) {
 
         logger.trace("Azure ChatCompletionsOptions: {}", options);
 
-        ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
+        ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
 
         logger.trace("Azure ChatCompletions: {}", chatCompletions);
 
@@ -178,6 +112,7 @@ public ChatResponse call(Prompt prompt) {
             .toList();
 
         PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
+
         return new ChatResponse(generations,
                 AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
     }
@@ -189,7 +124,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
         options.setStream(true);
 
         IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
-            .getChatCompletionsStream(this.getModel(), options);
+            .getChatCompletionsStream(options.getModel(), options);
 
         return Flux.fromStream(chatCompletionsStream.stream()
             // Note: the first chat completions can be ignored when using Azure OpenAI
@@ -205,7 +140,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
             }));
     }
 
-    private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
+    /**
+     * Test access.
+     */
+    ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
 
         List<ChatRequestMessage> azureMessages = prompt.getInstructions()
             .stream()
@@ -214,10 +152,27 @@ private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
 
         ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);
 
-        options.setTemperature(this.getTemperature());
-        options.setModel(this.getModel());
-        options.setTopP(this.getTopP());
-        options.setMaxTokens(this.getMaxTokens());
+        if (this.defaultOptions != null) {
+            // JSON merge doesn't due to Azure OpenAI service bug:
+            // https://github.com/Azure/azure-sdk-for-java/issues/38183
+            // options = ModelOptionsUtils.merge(options, this.defaultOptions,
+            // ChatCompletionsOptions.class);
+            options = merge(options, this.defaultOptions);
+        }
+
+        if (prompt.getOptions() != null) {
+            if (prompt.getOptions() instanceof AzureOpenAiChatOptions runtimeOptions) {
+                // JSON merge doesn't due to Azure OpenAI service bug:
+                // https://github.com/Azure/azure-sdk-for-java/issues/38183
+                // options = ModelOptionsUtils.merge(runtimeOptions, options,
+                // ChatCompletionsOptions.class);
+                options = merge(runtimeOptions, options);
+            }
+            else {
+                throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
+                        + prompt.getOptions().getClass().getSimpleName());
+            }
+        }
 
         return options;
     }
@@ -256,4 +211,121 @@ private <T> List<T> nullSafeList(List<T> list) {
         return list != null ? list : Collections.emptyList();
     }
 
+    // JSON merge doesn't due to Azure OpenAI service bug:
+    // https://github.com/Azure/azure-sdk-for-java/issues/38183
+    private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {
+
+        if (springAiOptions == null) {
+            return azureOptions;
+        }
+
+        ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+        mergedAzureOptions.setStream(azureOptions.isStream());
+
+        mergedAzureOptions.setMaxTokens(azureOptions.getMaxTokens());
+        if (mergedAzureOptions.getMaxTokens() == null) {
+            mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+        }
+
+        mergedAzureOptions.setLogitBias(azureOptions.getLogitBias());
+        if (mergedAzureOptions.getLogitBias() == null) {
+            mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+        }
+
+        mergedAzureOptions.setStop(azureOptions.getStop());
+        if (mergedAzureOptions.getStop() == null) {
+            mergedAzureOptions.setStop(springAiOptions.getStop());
+        }
+
+        mergedAzureOptions.setTemperature(azureOptions.getTemperature());
+        if (mergedAzureOptions.getTemperature() == null && springAiOptions.getTemperature() != null) {
+            mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+        }
+
+        mergedAzureOptions.setTopP(azureOptions.getTopP());
+        if (mergedAzureOptions.getTopP() == null && springAiOptions.getTopP() != null) {
+            mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+        }
+
+        mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
+        if (mergedAzureOptions.getFrequencyPenalty() == null && springAiOptions.getFrequencyPenalty() != null) {
+            mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+        }
+
+        mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty());
+        if (mergedAzureOptions.getPresencePenalty() == null && springAiOptions.getPresencePenalty() != null) {
+            mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+        }
+
+        mergedAzureOptions.setN(azureOptions.getN());
+        if (mergedAzureOptions.getN() == null) {
+            mergedAzureOptions.setN(springAiOptions.getN());
+        }
+
+        mergedAzureOptions.setUser(azureOptions.getUser());
+        if (mergedAzureOptions.getUser() == null) {
+            mergedAzureOptions.setUser(springAiOptions.getUser());
+        }
+
+        mergedAzureOptions.setModel(azureOptions.getModel());
+        if (mergedAzureOptions.getModel() == null) {
+            mergedAzureOptions.setModel(springAiOptions.getModel());
+        }
+
+        return mergedAzureOptions;
+    }
+
+    // JSON merge doesn't due to Azure OpenAI service bug:
+    // https://github.com/Azure/azure-sdk-for-java/issues/38183
+    private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
+        if (springAiOptions == null) {
+            return azureOptions;
+        }
+
+        ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+        mergedAzureOptions.setStream(azureOptions.isStream());
+
+        if (springAiOptions.getMaxTokens() != null) {
+            mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+        }
+
+        if (springAiOptions.getLogitBias() != null) {
+            mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+        }
+
+        if (springAiOptions.getStop() != null) {
+            mergedAzureOptions.setStop(springAiOptions.getStop());
+        }
+
+        if (springAiOptions.getTemperature() != null && springAiOptions.getTemperature() != null) {
+            mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+        }
+
+        if (springAiOptions.getTopP() != null && springAiOptions.getTopP() != null) {
+            mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+        }
+
+        if (springAiOptions.getFrequencyPenalty() != null && springAiOptions.getFrequencyPenalty() != null) {
+            mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+        }
+
+        if (springAiOptions.getPresencePenalty() != null && springAiOptions.getPresencePenalty() != null) {
+            mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+        }
+
+        if (springAiOptions.getN() != null) {
+            mergedAzureOptions.setN(springAiOptions.getN());
+        }
+
+        if (springAiOptions.getUser() != null) {
+            mergedAzureOptions.setUser(springAiOptions.getUser());
+        }
+
+        if (springAiOptions.getModel() != null) {
+            mergedAzureOptions.setModel(springAiOptions.getModel());
+        }
+
+        return mergedAzureOptions;
+    }
+
 }
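
The two merge overloads are applied in order inside toAzureChatCompletionsOptions: client defaults are folded into the request first, and any per-prompt AzureOpenAiChatOptions are then layered on top, so runtime values take precedence. The field-by-field copy stands in for a plain JSON merge because of the Azure SDK issue referenced in the comments (Azure/azure-sdk-for-java#38183). Since the commit widens toAzureChatCompletionsOptions to package visibility for test access, that precedence can be checked without calling the service. The following is a minimal test sketch, not one of the commit's own tests: the class name, the dummy endpoint and key, and the Prompt constructor taking per-request options are assumptions, and the class must sit in the client's package to reach the package-private method.

    package org.springframework.ai.azure.openai;

    import static org.assertj.core.api.Assertions.assertThat;

    import com.azure.ai.openai.OpenAIClient;
    import com.azure.ai.openai.OpenAIClientBuilder;
    import com.azure.ai.openai.models.ChatCompletionsOptions;
    import com.azure.core.credential.AzureKeyCredential;
    import org.junit.jupiter.api.Test;
    import org.springframework.ai.chat.prompt.Prompt;

    class AzureOpenAiChatClientOptionsTests {

        // Building the SDK client opens no connection; no request is sent in this test.
        private final OpenAIClient azureClient = new OpenAIClientBuilder()
            .endpoint("https://my-resource.openai.azure.com")
            .credential(new AzureKeyCredential("test-key"))
            .buildClient();

        @Test
        void runtimePromptOptionsOverrideDefaults() {
            AzureOpenAiChatClient client = new AzureOpenAiChatClient(azureClient)
                .withDefaultOptions(AzureOpenAiChatOptions.builder()
                    .withModel("gpt-35-turbo")
                    .withTemperature(0.7f)
                    .build());

            Prompt prompt = new Prompt("Tell me a joke",
                AzureOpenAiChatOptions.builder().withTemperature(0.2f).build());

            ChatCompletionsOptions merged = client.toAzureChatCompletionsOptions(prompt);

            // The per-prompt temperature (0.2) wins over the default (0.7).
            assertThat(merged.getTemperature().floatValue()).isEqualTo(0.2f);
        }
    }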
