Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,37 @@
*/
package org.springframework.ai.azure.openai;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
Expand All @@ -41,37 +72,10 @@
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
Expand Down Expand Up @@ -136,37 +140,16 @@ public ChatResponse call(Prompt prompt) {

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);

if (isToolFunctionCall(chatCompletions)) {
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
chatCompletions);
ChatResponse chatResponse = toChatResponse(chatCompletions);

if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
.withId(id)
.withUsage(usage)
.withModel(chatCompletions.getModel())
.withPromptMetadata(promptFilterMetadata)
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
.build();

return chatResponseMetadata;
return chatResponse;
}

@Override
Expand All @@ -179,10 +162,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
.getChatCompletionsStream(options.getModel(), options);

final var isFunctionCall = new AtomicBoolean(false);
final var accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
final Flux<ChatCompletions> accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
// .skip(1)
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
Expand All @@ -204,58 +186,70 @@ public Flux<ChatResponse> stream(Prompt prompt) {
})
.flatMap(mono -> mono);

return accessibleChatCompletionsFlux.switchMap(chatCompletion -> {
if (isToolFunctionCall(chatCompletion)) {
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
chatCompletion);
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {

ChatResponse chatResponse = toChatResponse(chatCompletions);

if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}

return Mono.just(chatCompletion).flatMapIterable(ChatCompletions::getChoices).map(choice -> {
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
});
return Mono.just(chatResponse);
});
}

private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletions chatCompletion) {
private ChatResponse toChatResponse(ChatCompletions chatCompletions) {

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
"choiceIndex", choice.getIndex(),
"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatRequestAssistantMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion);
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

List<AssistantMessage.ToolCall> assistantToolCalls = nativeAssistantMessage.getToolCalls()
.stream()
.map(tc -> (ChatCompletionsFunctionToolCall) tc)
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), toolCall.getType(),
toolCall.getFunction().getName(), toolCall.getFunction().getArguments()))
.toList();
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
}

private Generation buildGeneration(ChatChoice choice, Map<String, Object> metadata) {

AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(),
assistantToolCalls);
var responseMessage = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta());

ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage);
List<AssistantMessage.ToolCall> toolCalls = responseMessage.getToolCalls() == null ? List.of()
: responseMessage.getToolCalls().stream().map(toolCall -> {
final var tc1 = (ChatCompletionsFunctionToolCall) toolCall;
String id = tc1.getId();
String name = tc1.getFunction().getName();
String arguments = tc1.getFunction().getArguments();
return new AssistantMessage.ToolCall(id, "function", name, arguments);
}).toList();

// History
List<Message> messages = new ArrayList<>(previousMessages);
messages.add(assistantMessage);
messages.add(toolResponseMessage);
var assistantMessage = new AssistantMessage(responseMessage.getContent(), metadata, toolCalls);
var generationMetadata = generateChoiceMetadata(choice);

return messages;
return new Generation(assistantMessage, generationMetadata);
}

private ChatRequestAssistantMessage extractAssistantMessage(ChatCompletions response) {
final var accessibleChatChoice = response.getChoices().get(0);
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
.orElse(accessibleChatChoice.getDelta());
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
final var toolCalls = responseMessage.getToolCalls();
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
final var tc1 = (ChatCompletionsFunctionToolCall) tc;
var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(),
new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments()));
return ((ChatCompletionsToolCall) toDowncast);
}).toList());
return assistantMessage;
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
.withId(id)
.withUsage(usage)
.withModel(chatCompletions.getModel())
.withPromptMetadata(promptFilterMetadata)
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
.build();

return chatResponseMetadata;
}

/**
Expand Down Expand Up @@ -560,21 +554,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
return copyOptions;
}

protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {

if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) {
return false;
}

var choice = chatCompletions.getChoices().get(0);

if (choice == null || choice.getFinishReason() == null) {
return false;
}

return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
}

/**
* Maps the SpringAI response format to the Azure response format
* @param responseFormat SpringAI response format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public static String getDeploymentName() {
return deploymentName;
}
else {
return "gpt-4-0125-preview";
return "gpt-4o";
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public static String getDeploymentName() {
return deploymentName;
}
else {
return "gpt-4-0125-preview";
return "gpt-4o";
}
}

Expand Down