@@ -1,5 +1,5 @@
/*
- * Copyright 2023 the original author or authors.
+ * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -27,23 +27,22 @@
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
- import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
- import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import reactor.core.publisher.Flux;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
+ import org.springframework.ai.chat.messages.Message;
+ import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.prompt.Prompt;
- import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

/**
@@ -59,104 +58,39 @@
 */
public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {

-     /**
-      * The sampling temperature to use that controls the apparent creativity of generated
-      * completions. Higher values will make output more random while lower values will
-      * make results more focused and deterministic. It is not recommended to modify
-      * temperature and top_p for the same completions request as the interaction of these
-      * two settings is difficult to predict.
-      */
-     private Double temperature = 0.7;
+     private static final String DEFAULT_MODEL = "gpt-35-turbo";

-     /**
-      * An alternative to sampling with temperature called nucleus sampling. This value
-      * causes the model to consider the results of tokens with the provided probability
-      * mass. As an example, a value of 0.15 will cause only the tokens comprising the top
-      * 15% of probability mass to be considered. It is not recommended to modify
-      * temperature and top_p for the same completions request as the interaction of these
-      * two settings is difficult to predict.
-      */
-     private Double topP;
+     private static final Float DEFAULT_TEMPERATURE = 0.7f;
+
+     private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
-      * Creates an instance of ChatCompletionsOptions class.
+      * The configuration information for a chat completions request.
     */
-     private String model = "gpt-35-turbo";
+     private AzureOpenAiChatOptions defaultOptions;

    /**
-      * The maximum number of tokens to generate.
+      * The {@link OpenAIClient} used to interact with the Azure OpenAI service.
     */
-     private Integer maxTokens;
-
-     private final Logger logger = LoggerFactory.getLogger(getClass());
-
    private final OpenAIClient openAIClient;

    public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
        Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
        this.openAIClient = microsoftOpenAiClient;
+         this.defaultOptions = AzureOpenAiChatOptions.builder()
+             .withModel(DEFAULT_MODEL)
+             .withTemperature(DEFAULT_TEMPERATURE)
+             .build();
    }

-     public String getModel() {
-         return this.model;
-     }
-
-     public AzureOpenAiChatClient withModel(String model) {
-         this.model = model;
-         return this;
-     }
-
-     public Double getTemperature() {
-         return this.temperature;
-     }
-
-     public AzureOpenAiChatClient withTemperature(Double temperature) {
-         this.temperature = temperature;
-         return this;
-     }
-
-     public Double getTopP() {
-         return topP;
-     }
-
-     public AzureOpenAiChatClient withTopP(Double topP) {
-         this.topP = topP;
+     public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
+         Assert.notNull(defaultOptions, "DefaultOptions must not be null");
+         this.defaultOptions = defaultOptions;
        return this;
    }

-     public Integer getMaxTokens() {
-         return maxTokens;
-     }
-
-     public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) {
-         this.maxTokens = maxTokens;
-         return this;
-     }
-
-     @Override
-     public String call(String text) {
-
-         ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text);
-
-         ChatCompletionsOptions options = new ChatCompletionsOptions(List.of(azureChatMessage));
-         options.setTemperature(this.getTemperature());
-         options.setModel(this.getModel());
-
-         logger.trace("Azure Chat Message: {}", azureChatMessage);
-
-         ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
-         logger.trace("Azure ChatCompletions: {}", chatCompletions);
-
-         StringBuilder stringBuilder = new StringBuilder();
-
-         for (ChatChoice choice : chatCompletions.getChoices()) {
-             ChatResponseMessage message = choice.getMessage();
-             if (message != null && message.getContent() != null) {
-                 stringBuilder.append(message.getContent());
-             }
-         }
-
-         return stringBuilder.toString();
+     public AzureOpenAiChatOptions getDefaultOptions() {
+         return defaultOptions;
    }

    @Override
@@ -167,7 +101,7 @@ public ChatResponse call(Prompt prompt) {

        logger.trace("Azure ChatCompletionsOptions: {}", options);

-         ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(this.getModel(), options);
+         ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);

        logger.trace("Azure ChatCompletions: {}", chatCompletions);

@@ -178,6 +112,7 @@ public ChatResponse call(Prompt prompt) {
            .toList();

        PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
+
        return new ChatResponse(generations,
                AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
    }
@@ -189,7 +124,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
        options.setStream(true);

        IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
-             .getChatCompletionsStream(this.getModel(), options);
+             .getChatCompletionsStream(options.getModel(), options);

        return Flux.fromStream(chatCompletionsStream.stream()
            // Note: the first chat completions can be ignored when using Azure OpenAI
@@ -205,7 +140,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
            }));
    }

-     private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
+     /**
+      * Test access.
+      */
+     ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

        List<ChatRequestMessage> azureMessages = prompt.getInstructions()
            .stream()
@@ -214,10 +152,27 @@ private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

        ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

-         options.setTemperature(this.getTemperature());
-         options.setModel(this.getModel());
-         options.setTopP(this.getTopP());
-         options.setMaxTokens(this.getMaxTokens());
+         if (this.defaultOptions != null) {
+             // JSON merge doesn't work due to an Azure OpenAI service bug:
+             // https://github.com/Azure/azure-sdk-for-java/issues/38183
+             // options = ModelOptionsUtils.merge(options, this.defaultOptions,
+             // ChatCompletionsOptions.class);
+             options = merge(options, this.defaultOptions);
+         }
+
+         if (prompt.getOptions() != null) {
+             if (prompt.getOptions() instanceof AzureOpenAiChatOptions runtimeOptions) {
+                 // JSON merge doesn't work due to an Azure OpenAI service bug:
+                 // https://github.com/Azure/azure-sdk-for-java/issues/38183
+                 // options = ModelOptionsUtils.merge(runtimeOptions, options,
+                 // ChatCompletionsOptions.class);
+                 options = merge(runtimeOptions, options);
+             }
+             else {
+                 throw new IllegalArgumentException("Prompt options are not of type AzureOpenAiChatOptions: "
+                         + prompt.getOptions().getClass().getSimpleName());
+             }
+         }

        return options;
    }
@@ -256,4 +211,121 @@ private <T> List<T> nullSafeList(List<T> list) {
        return list != null ? list : Collections.emptyList();
    }

+     // JSON merge doesn't work due to an Azure OpenAI service bug:
+     // https://github.com/Azure/azure-sdk-for-java/issues/38183
+     private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {
+
+         if (springAiOptions == null) {
+             return azureOptions;
+         }
+
+         ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+         mergedAzureOptions.setStream(azureOptions.isStream());
+
+         mergedAzureOptions.setMaxTokens(azureOptions.getMaxTokens());
+         if (mergedAzureOptions.getMaxTokens() == null) {
+             mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+         }
+
+         mergedAzureOptions.setLogitBias(azureOptions.getLogitBias());
+         if (mergedAzureOptions.getLogitBias() == null) {
+             mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+         }
+
+         mergedAzureOptions.setStop(azureOptions.getStop());
+         if (mergedAzureOptions.getStop() == null) {
+             mergedAzureOptions.setStop(springAiOptions.getStop());
+         }
+
+         mergedAzureOptions.setTemperature(azureOptions.getTemperature());
+         if (mergedAzureOptions.getTemperature() == null && springAiOptions.getTemperature() != null) {
+             mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+         }
+
+         mergedAzureOptions.setTopP(azureOptions.getTopP());
+         if (mergedAzureOptions.getTopP() == null && springAiOptions.getTopP() != null) {
+             mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+         }
+
+         mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
+         if (mergedAzureOptions.getFrequencyPenalty() == null && springAiOptions.getFrequencyPenalty() != null) {
+             mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+         }
+
+         mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty());
+         if (mergedAzureOptions.getPresencePenalty() == null && springAiOptions.getPresencePenalty() != null) {
+             mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+         }
+
+         mergedAzureOptions.setN(azureOptions.getN());
+         if (mergedAzureOptions.getN() == null) {
+             mergedAzureOptions.setN(springAiOptions.getN());
+         }
+
+         mergedAzureOptions.setUser(azureOptions.getUser());
+         if (mergedAzureOptions.getUser() == null) {
+             mergedAzureOptions.setUser(springAiOptions.getUser());
+         }
+
+         mergedAzureOptions.setModel(azureOptions.getModel());
+         if (mergedAzureOptions.getModel() == null) {
+             mergedAzureOptions.setModel(springAiOptions.getModel());
+         }
+
+         return mergedAzureOptions;
+     }
+
+     // JSON merge doesn't work due to an Azure OpenAI service bug:
+     // https://github.com/Azure/azure-sdk-for-java/issues/38183
+     private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
+         if (springAiOptions == null) {
+             return azureOptions;
+         }
+
+         ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
+         mergedAzureOptions.setStream(azureOptions.isStream());
+
+         if (springAiOptions.getMaxTokens() != null) {
+             mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
+         }
+
+         if (springAiOptions.getLogitBias() != null) {
+             mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
+         }
+
+         if (springAiOptions.getStop() != null) {
+             mergedAzureOptions.setStop(springAiOptions.getStop());
+         }
+
+         if (springAiOptions.getTemperature() != null) {
+             mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
+         }
+
+         if (springAiOptions.getTopP() != null) {
+             mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
+         }
+
+         if (springAiOptions.getFrequencyPenalty() != null) {
+             mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
+         }
+
+         if (springAiOptions.getPresencePenalty() != null) {
+             mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
+         }
+
+         if (springAiOptions.getN() != null) {
+             mergedAzureOptions.setN(springAiOptions.getN());
+         }
+
+         if (springAiOptions.getUser() != null) {
+             mergedAzureOptions.setUser(springAiOptions.getUser());
+         }
+
+         if (springAiOptions.getModel() != null) {
+             mergedAzureOptions.setModel(springAiOptions.getModel());
+         }
+
+         return mergedAzureOptions;
+     }
+

}
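
For reference, a minimal usage sketch of the resulting API (not part of the commit). The OpenAIClientBuilder/AzureKeyCredential setup comes from the Azure SDK for Java; the package names, the Prompt(String, ChatOptions) constructor, and the ChatResponse accessors outside this file are assumptions based on the Spring AI 0.8.x-era API shown in the diff.

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;

import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

public class AzureChatQuickstart {

    public static void main(String[] args) {
        // Low-level Azure SDK client; endpoint and key env vars are placeholders.
        OpenAIClient openAIClient = new OpenAIClientBuilder()
            .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
            .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
            .buildClient();

        // Defaults applied to every request; this replaces the removed withModel()/withTemperature() setters.
        AzureOpenAiChatClient chatClient = new AzureOpenAiChatClient(openAIClient)
            .withDefaultOptions(AzureOpenAiChatOptions.builder()
                .withModel("gpt-35-turbo")
                .withTemperature(0.4f)
                .build());

        // Per-prompt options override the defaults through the merge(...) methods called from
        // toAzureChatCompletionsOptions(Prompt). The Prompt constructor and the accessors below
        // are assumed from the Spring AI 0.8.x API rather than from this file.
        ChatResponse response = chatClient.call(
                new Prompt("Tell me a joke", AzureOpenAiChatOptions.builder().withTemperature(0.9f).build()));
        System.out.println(response.getResult().getOutput().getContent());
    }

}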