Skip to content

Commit 736bd24

Browse files
committed
Streamline Azure OpenAI Function Calling
1 parent 9aee4aa commit 736bd24

File tree

3 files changed

+92
-113
lines changed

3 files changed

+92
-113
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 90 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,37 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18+
import java.util.ArrayList;
19+
import java.util.Collections;
20+
import java.util.HashSet;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Optional;
24+
import java.util.Set;
25+
import java.util.concurrent.atomic.AtomicBoolean;
26+
27+
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
28+
import org.springframework.ai.chat.messages.AssistantMessage;
29+
import org.springframework.ai.chat.messages.Message;
30+
import org.springframework.ai.chat.messages.ToolResponseMessage;
31+
import org.springframework.ai.chat.messages.UserMessage;
32+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
33+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
34+
import org.springframework.ai.chat.metadata.EmptyUsage;
35+
import org.springframework.ai.chat.metadata.PromptMetadata;
36+
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
37+
import org.springframework.ai.chat.metadata.Usage;
38+
import org.springframework.ai.chat.model.AbstractToolCallSupport;
39+
import org.springframework.ai.chat.model.ChatModel;
40+
import org.springframework.ai.chat.model.ChatResponse;
41+
import org.springframework.ai.chat.model.Generation;
42+
import org.springframework.ai.chat.prompt.ChatOptions;
43+
import org.springframework.ai.chat.prompt.Prompt;
44+
import org.springframework.ai.model.ModelOptionsUtils;
45+
import org.springframework.ai.model.function.FunctionCallbackContext;
46+
import org.springframework.util.Assert;
47+
import org.springframework.util.CollectionUtils;
48+
1849
import com.azure.ai.openai.OpenAIClient;
1950
import com.azure.ai.openai.models.ChatChoice;
2051
import com.azure.ai.openai.models.ChatCompletions;
@@ -41,37 +72,10 @@
4172
import com.azure.ai.openai.models.FunctionDefinition;
4273
import com.azure.core.util.BinaryData;
4374
import com.azure.core.util.IterableStream;
44-
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
45-
import org.springframework.ai.chat.messages.AssistantMessage;
46-
import org.springframework.ai.chat.messages.Message;
47-
import org.springframework.ai.chat.messages.ToolResponseMessage;
48-
import org.springframework.ai.chat.messages.UserMessage;
49-
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
50-
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
51-
import org.springframework.ai.chat.metadata.PromptMetadata;
52-
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
53-
import org.springframework.ai.chat.model.ChatModel;
54-
import org.springframework.ai.chat.model.ChatResponse;
55-
import org.springframework.ai.chat.model.Generation;
56-
import org.springframework.ai.chat.prompt.ChatOptions;
57-
import org.springframework.ai.chat.prompt.Prompt;
58-
import org.springframework.ai.model.ModelOptionsUtils;
59-
import org.springframework.ai.chat.model.AbstractToolCallSupport;
60-
import org.springframework.ai.model.function.FunctionCallbackContext;
61-
import org.springframework.util.Assert;
62-
import org.springframework.util.CollectionUtils;
75+
6376
import reactor.core.publisher.Flux;
6477
import reactor.core.publisher.Mono;
6578

66-
import java.util.ArrayList;
67-
import java.util.Collections;
68-
import java.util.HashSet;
69-
import java.util.List;
70-
import java.util.Map;
71-
import java.util.Optional;
72-
import java.util.Set;
73-
import java.util.concurrent.atomic.AtomicBoolean;
74-
7579
/**
7680
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
7781
* {@link OpenAIClient}.
@@ -136,37 +140,16 @@ public ChatResponse call(Prompt prompt) {
136140

137141
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
138142

139-
if (isToolFunctionCall(chatCompletions)) {
140-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
141-
chatCompletions);
143+
ChatResponse chatResponse = toChatResponse(chatCompletions);
144+
145+
if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
146+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
142147
// Recursively call the call method with the tool call message
143148
// conversation that contains the call responses.
144-
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
149+
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
145150
}
146151

147-
List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
148-
.map(choice -> new Generation(choice.getMessage().getContent())
149-
.withGenerationMetadata(generateChoiceMetadata(choice)))
150-
.toList();
151-
152-
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
153-
154-
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
155-
}
156-
157-
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
158-
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
159-
String id = chatCompletions.getId();
160-
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
161-
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
162-
.withId(id)
163-
.withUsage(usage)
164-
.withModel(chatCompletions.getModel())
165-
.withPromptMetadata(promptFilterMetadata)
166-
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
167-
.build();
168-
169-
return chatResponseMetadata;
152+
return chatResponse;
170153
}
171154

172155
@Override
@@ -179,10 +162,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
179162
.getChatCompletionsStream(options.getModel(), options);
180163

181164
final var isFunctionCall = new AtomicBoolean(false);
182-
final var accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
165+
final Flux<ChatCompletions> accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
183166
// Note: the first chat completions can be ignored when using Azure OpenAI
184167
// service which is a known service bug.
185-
// .skip(1)
186168
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
187169
.map(chatCompletions -> {
188170
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
@@ -204,58 +186,70 @@ public Flux<ChatResponse> stream(Prompt prompt) {
204186
})
205187
.flatMap(mono -> mono);
206188

207-
return accessibleChatCompletionsFlux.switchMap(chatCompletion -> {
208-
if (isToolFunctionCall(chatCompletion)) {
209-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
210-
chatCompletion);
211-
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
189+
return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
190+
191+
ChatResponse chatResponse = toChatResponse(chatCompletions);
192+
193+
if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
194+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
195+
// Recursively call the call method with the tool call message
196+
// conversation that contains the call responses.
197+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
212198
}
213199

214-
return Mono.just(chatCompletion).flatMapIterable(ChatCompletions::getChoices).map(choice -> {
215-
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
216-
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
217-
return new ChatResponse(List.of(generation));
218-
});
200+
return Mono.just(chatResponse);
219201
});
220202
}
221203

222-
private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletions chatCompletion) {
204+
private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
205+
206+
List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
207+
// @formatter:off
208+
Map<String, Object> metadata = Map.of(
209+
"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
210+
"choiceIndex", choice.getIndex(),
211+
"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
212+
// @formatter:on
213+
return buildGeneration(choice, metadata);
214+
}).toList();
223215

224-
ChatRequestAssistantMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion);
216+
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
225217

226-
List<AssistantMessage.ToolCall> assistantToolCalls = nativeAssistantMessage.getToolCalls()
227-
.stream()
228-
.map(tc -> (ChatCompletionsFunctionToolCall) tc)
229-
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), toolCall.getType(),
230-
toolCall.getFunction().getName(), toolCall.getFunction().getArguments()))
231-
.toList();
218+
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
219+
}
220+
221+
private Generation buildGeneration(ChatChoice choice, Map<String, Object> metadata) {
232222

233-
AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(),
234-
assistantToolCalls);
223+
var responseMessage = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta());
235224

236-
ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage);
225+
List<AssistantMessage.ToolCall> toolCalls = responseMessage.getToolCalls() == null ? List.of()
226+
: responseMessage.getToolCalls().stream().map(toolCall -> {
227+
final var tc1 = (ChatCompletionsFunctionToolCall) toolCall;
228+
String id = tc1.getId();
229+
String name = tc1.getFunction().getName();
230+
String arguments = tc1.getFunction().getArguments();
231+
return new AssistantMessage.ToolCall(id, "function", name, arguments);
232+
}).toList();
237233

238-
// History
239-
List<Message> messages = new ArrayList<>(previousMessages);
240-
messages.add(assistantMessage);
241-
messages.add(toolResponseMessage);
234+
var assistantMessage = new AssistantMessage(responseMessage.getContent(), metadata, toolCalls);
235+
var generationMetadata = generateChoiceMetadata(choice);
242236

243-
return messages;
237+
return new Generation(assistantMessage, generationMetadata);
244238
}
245239

246-
private ChatRequestAssistantMessage extractAssistantMessage(ChatCompletions response) {
247-
final var accessibleChatChoice = response.getChoices().get(0);
248-
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
249-
.orElse(accessibleChatChoice.getDelta());
250-
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
251-
final var toolCalls = responseMessage.getToolCalls();
252-
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
253-
final var tc1 = (ChatCompletionsFunctionToolCall) tc;
254-
var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(),
255-
new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments()));
256-
return ((ChatCompletionsToolCall) toDowncast);
257-
}).toList());
258-
return assistantMessage;
240+
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
241+
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
242+
String id = chatCompletions.getId();
243+
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
244+
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
245+
.withId(id)
246+
.withUsage(usage)
247+
.withModel(chatCompletions.getModel())
248+
.withPromptMetadata(promptFilterMetadata)
249+
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
250+
.build();
251+
252+
return chatResponseMetadata;
259253
}
260254

261255
/**
@@ -560,21 +554,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
560554
return copyOptions;
561555
}
562556

563-
protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
564-
565-
if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) {
566-
return false;
567-
}
568-
569-
var choice = chatCompletions.getChoices().get(0);
570-
571-
if (choice == null || choice.getFinishReason() == null) {
572-
return false;
573-
}
574-
575-
return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
576-
}
577-
578557
/**
579558
* Maps the SpringAI response format to the Azure response format
580559
* @param responseFormat SpringAI response format

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ public static String getDeploymentName() {
206206
return deploymentName;
207207
}
208208
else {
209-
return "gpt-4-0125-preview";
209+
return "gpt-4o";
210210
}
211211
}
212212

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public static String getDeploymentName() {
1010
return deploymentName;
1111
}
1212
else {
13-
return "gpt-4-0125-preview";
13+
return "gpt-4o";
1414
}
1515
}
1616

0 commit comments

Comments
 (0)