15
15
*/
16
16
package org .springframework .ai .azure .openai ;
17
17
18
+ import java .util .ArrayList ;
19
+ import java .util .Collections ;
20
+ import java .util .HashSet ;
21
+ import java .util .List ;
22
+ import java .util .Map ;
23
+ import java .util .Optional ;
24
+ import java .util .Set ;
25
+ import java .util .concurrent .atomic .AtomicBoolean ;
26
+
27
+ import org .springframework .ai .azure .openai .metadata .AzureOpenAiUsage ;
28
+ import org .springframework .ai .chat .messages .AssistantMessage ;
29
+ import org .springframework .ai .chat .messages .Message ;
30
+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
31
+ import org .springframework .ai .chat .messages .UserMessage ;
32
+ import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
33
+ import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
34
+ import org .springframework .ai .chat .metadata .EmptyUsage ;
35
+ import org .springframework .ai .chat .metadata .PromptMetadata ;
36
+ import org .springframework .ai .chat .metadata .PromptMetadata .PromptFilterMetadata ;
37
+ import org .springframework .ai .chat .metadata .Usage ;
38
+ import org .springframework .ai .chat .model .AbstractToolCallSupport ;
39
+ import org .springframework .ai .chat .model .ChatModel ;
40
+ import org .springframework .ai .chat .model .ChatResponse ;
41
+ import org .springframework .ai .chat .model .Generation ;
42
+ import org .springframework .ai .chat .prompt .ChatOptions ;
43
+ import org .springframework .ai .chat .prompt .Prompt ;
44
+ import org .springframework .ai .model .ModelOptionsUtils ;
45
+ import org .springframework .ai .model .function .FunctionCallbackContext ;
46
+ import org .springframework .util .Assert ;
47
+ import org .springframework .util .CollectionUtils ;
48
+
18
49
import com .azure .ai .openai .OpenAIClient ;
19
50
import com .azure .ai .openai .models .ChatChoice ;
20
51
import com .azure .ai .openai .models .ChatCompletions ;
41
72
import com .azure .ai .openai .models .FunctionDefinition ;
42
73
import com .azure .core .util .BinaryData ;
43
74
import com .azure .core .util .IterableStream ;
44
- import org .springframework .ai .azure .openai .metadata .AzureOpenAiUsage ;
45
- import org .springframework .ai .chat .messages .AssistantMessage ;
46
- import org .springframework .ai .chat .messages .Message ;
47
- import org .springframework .ai .chat .messages .ToolResponseMessage ;
48
- import org .springframework .ai .chat .messages .UserMessage ;
49
- import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
50
- import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
51
- import org .springframework .ai .chat .metadata .PromptMetadata ;
52
- import org .springframework .ai .chat .metadata .PromptMetadata .PromptFilterMetadata ;
53
- import org .springframework .ai .chat .model .ChatModel ;
54
- import org .springframework .ai .chat .model .ChatResponse ;
55
- import org .springframework .ai .chat .model .Generation ;
56
- import org .springframework .ai .chat .prompt .ChatOptions ;
57
- import org .springframework .ai .chat .prompt .Prompt ;
58
- import org .springframework .ai .model .ModelOptionsUtils ;
59
- import org .springframework .ai .chat .model .AbstractToolCallSupport ;
60
- import org .springframework .ai .model .function .FunctionCallbackContext ;
61
- import org .springframework .util .Assert ;
62
- import org .springframework .util .CollectionUtils ;
75
+
63
76
import reactor .core .publisher .Flux ;
64
77
import reactor .core .publisher .Mono ;
65
78
66
- import java .util .ArrayList ;
67
- import java .util .Collections ;
68
- import java .util .HashSet ;
69
- import java .util .List ;
70
- import java .util .Map ;
71
- import java .util .Optional ;
72
- import java .util .Set ;
73
- import java .util .concurrent .atomic .AtomicBoolean ;
74
-
75
79
/**
76
80
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
77
81
* {@link OpenAIClient}.
@@ -136,37 +140,16 @@ public ChatResponse call(Prompt prompt) {
136
140
137
141
ChatCompletions chatCompletions = this .openAIClient .getChatCompletions (options .getModel (), options );
138
142
139
- if (isToolFunctionCall (chatCompletions )) {
140
- List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
141
- chatCompletions );
143
+ ChatResponse chatResponse = toChatResponse (chatCompletions );
144
+
145
+ if (isToolCall (chatResponse , Set .of (String .valueOf (CompletionsFinishReason .TOOL_CALLS ).toLowerCase ()))) {
146
+ var toolCallConversation = handleToolCalls (prompt , chatResponse );
142
147
// Recursively call the call method with the tool call message
143
148
// conversation that contains the call responses.
144
- return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
149
+ return this .call (new Prompt (toolCallConversation , prompt .getOptions ()));
145
150
}
146
151
147
- List <Generation > generations = nullSafeList (chatCompletions .getChoices ()).stream ()
148
- .map (choice -> new Generation (choice .getMessage ().getContent ())
149
- .withGenerationMetadata (generateChoiceMetadata (choice )))
150
- .toList ();
151
-
152
- PromptMetadata promptFilterMetadata = generatePromptMetadata (chatCompletions );
153
-
154
- return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata ));
155
- }
156
-
157
- public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ) {
158
- Assert .notNull (chatCompletions , "Azure OpenAI ChatCompletions must not be null" );
159
- String id = chatCompletions .getId ();
160
- AzureOpenAiUsage usage = AzureOpenAiUsage .from (chatCompletions );
161
- ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata .builder ()
162
- .withId (id )
163
- .withUsage (usage )
164
- .withModel (chatCompletions .getModel ())
165
- .withPromptMetadata (promptFilterMetadata )
166
- .withKeyValue ("system-fingerprint" , chatCompletions .getSystemFingerprint ())
167
- .build ();
168
-
169
- return chatResponseMetadata ;
152
+ return chatResponse ;
170
153
}
171
154
172
155
@ Override
@@ -179,10 +162,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
179
162
.getChatCompletionsStream (options .getModel (), options );
180
163
181
164
final var isFunctionCall = new AtomicBoolean (false );
182
- final var accessibleChatCompletionsFlux = Flux .fromIterable (chatCompletionsStream )
165
+ final Flux < ChatCompletions > accessibleChatCompletionsFlux = Flux .fromIterable (chatCompletionsStream )
183
166
// Note: the first chat completions can be ignored when using Azure OpenAI
184
167
// service which is a known service bug.
185
- // .skip(1)
186
168
.filter (chatCompletions -> !CollectionUtils .isEmpty (chatCompletions .getChoices ()))
187
169
.map (chatCompletions -> {
188
170
final var toolCalls = chatCompletions .getChoices ().get (0 ).getDelta ().getToolCalls ();
@@ -204,58 +186,70 @@ public Flux<ChatResponse> stream(Prompt prompt) {
204
186
})
205
187
.flatMap (mono -> mono );
206
188
207
- return accessibleChatCompletionsFlux .switchMap (chatCompletion -> {
208
- if (isToolFunctionCall (chatCompletion )) {
209
- List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
210
- chatCompletion );
211
- return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
189
+ return accessibleChatCompletionsFlux .switchMap (chatCompletions -> {
190
+
191
+ ChatResponse chatResponse = toChatResponse (chatCompletions );
192
+
193
+ if (isToolCall (chatResponse , Set .of (String .valueOf (CompletionsFinishReason .TOOL_CALLS ).toLowerCase ()))) {
194
+ var toolCallConversation = handleToolCalls (prompt , chatResponse );
195
+ // Recursively call the call method with the tool call message
196
+ // conversation that contains the call responses.
197
+ return this .stream (new Prompt (toolCallConversation , prompt .getOptions ()));
212
198
}
213
199
214
- return Mono .just (chatCompletion ).flatMapIterable (ChatCompletions ::getChoices ).map (choice -> {
215
- var content = Optional .ofNullable (choice .getMessage ()).orElse (choice .getDelta ()).getContent ();
216
- var generation = new Generation (content ).withGenerationMetadata (generateChoiceMetadata (choice ));
217
- return new ChatResponse (List .of (generation ));
218
- });
200
+ return Mono .just (chatResponse );
219
201
});
220
202
}
221
203
222
- private List <Message > handleToolCallRequests (List <Message > previousMessages , ChatCompletions chatCompletion ) {
204
+ private ChatResponse toChatResponse (ChatCompletions chatCompletions ) {
205
+
206
+ List <Generation > generations = nullSafeList (chatCompletions .getChoices ()).stream ().map (choice -> {
207
+ // @formatter:off
208
+ Map <String , Object > metadata = Map .of (
209
+ "id" , chatCompletions .getId () != null ? chatCompletions .getId () : "" ,
210
+ "choiceIndex" , choice .getIndex (),
211
+ "finishReason" , choice .getFinishReason () != null ? String .valueOf (choice .getFinishReason ()) : "" );
212
+ // @formatter:on
213
+ return buildGeneration (choice , metadata );
214
+ }).toList ();
223
215
224
- ChatRequestAssistantMessage nativeAssistantMessage = this . extractAssistantMessage ( chatCompletion );
216
+ PromptMetadata promptFilterMetadata = generatePromptMetadata ( chatCompletions );
225
217
226
- List <AssistantMessage .ToolCall > assistantToolCalls = nativeAssistantMessage .getToolCalls ()
227
- .stream ()
228
- .map (tc -> (ChatCompletionsFunctionToolCall ) tc )
229
- .map (toolCall -> new AssistantMessage .ToolCall (toolCall .getId (), toolCall .getType (),
230
- toolCall .getFunction ().getName (), toolCall .getFunction ().getArguments ()))
231
- .toList ();
218
+ return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata ));
219
+ }
220
+
221
+ private Generation buildGeneration (ChatChoice choice , Map <String , Object > metadata ) {
232
222
233
- AssistantMessage assistantMessage = new AssistantMessage (nativeAssistantMessage .getContent (), Map .of (),
234
- assistantToolCalls );
223
+ var responseMessage = Optional .ofNullable (choice .getMessage ()).orElse (choice .getDelta ());
235
224
236
- ToolResponseMessage toolResponseMessage = this .executeFunctions (assistantMessage );
225
+ List <AssistantMessage .ToolCall > toolCalls = responseMessage .getToolCalls () == null ? List .of ()
226
+ : responseMessage .getToolCalls ().stream ().map (toolCall -> {
227
+ final var tc1 = (ChatCompletionsFunctionToolCall ) toolCall ;
228
+ String id = tc1 .getId ();
229
+ String name = tc1 .getFunction ().getName ();
230
+ String arguments = tc1 .getFunction ().getArguments ();
231
+ return new AssistantMessage .ToolCall (id , "function" , name , arguments );
232
+ }).toList ();
237
233
238
- // History
239
- List <Message > messages = new ArrayList <>(previousMessages );
240
- messages .add (assistantMessage );
241
- messages .add (toolResponseMessage );
234
+ var assistantMessage = new AssistantMessage (responseMessage .getContent (), metadata , toolCalls );
235
+ var generationMetadata = generateChoiceMetadata (choice );
242
236
243
- return messages ;
237
+ return new Generation ( assistantMessage , generationMetadata ) ;
244
238
}
245
239
246
- private ChatRequestAssistantMessage extractAssistantMessage (ChatCompletions response ) {
247
- final var accessibleChatChoice = response . getChoices (). get ( 0 );
248
- var responseMessage = Optional . ofNullable ( accessibleChatChoice . getMessage ())
249
- . orElse ( accessibleChatChoice . getDelta () );
250
- ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage ( "" );
251
- final var toolCalls = responseMessage . getToolCalls ();
252
- assistantMessage . setToolCalls ( toolCalls . stream (). map ( tc -> {
253
- final var tc1 = ( ChatCompletionsFunctionToolCall ) tc ;
254
- var toDowncast = new ChatCompletionsFunctionToolCall ( tc . getId (),
255
- new FunctionCall ( tc1 . getFunction (). getName (), tc1 . getFunction (). getArguments ()));
256
- return (( ChatCompletionsToolCall ) toDowncast );
257
- }). toList ());
258
- return assistantMessage ;
240
+ public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ) {
241
+ Assert . notNull ( chatCompletions , "Azure OpenAI ChatCompletions must not be null" );
242
+ String id = chatCompletions . getId ();
243
+ Usage usage = ( chatCompletions . getUsage () != null ) ? AzureOpenAiUsage . from ( chatCompletions ) : new EmptyUsage ( );
244
+ ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata . builder ()
245
+ . withId ( id )
246
+ . withUsage ( usage )
247
+ . withModel ( chatCompletions . getModel ())
248
+ . withPromptMetadata ( promptFilterMetadata )
249
+ . withKeyValue ( "system-fingerprint" , chatCompletions . getSystemFingerprint ())
250
+ . build ( );
251
+
252
+ return chatResponseMetadata ;
259
253
}
260
254
261
255
/**
@@ -560,21 +554,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
560
554
return copyOptions ;
561
555
}
562
556
563
- protected boolean isToolFunctionCall (ChatCompletions chatCompletions ) {
564
-
565
- if (chatCompletions == null || CollectionUtils .isEmpty (chatCompletions .getChoices ())) {
566
- return false ;
567
- }
568
-
569
- var choice = chatCompletions .getChoices ().get (0 );
570
-
571
- if (choice == null || choice .getFinishReason () == null ) {
572
- return false ;
573
- }
574
-
575
- return choice .getFinishReason () == CompletionsFinishReason .TOOL_CALLS ;
576
- }
577
-
578
557
/**
579
558
* Maps the SpringAI response format to the Azure response format
580
559
* @param responseFormat SpringAI response format
0 commit comments