AzureOpenAiChatModel.java
@@ -51,6 +51,7 @@
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.core.util.BinaryData;
@@ -70,6 +71,7 @@
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
@@ -105,6 +107,7 @@
* @author timostark
* @author Soby Chacko
* @author Jihoon Kim
* @author Ilayaperumal Gopinathan
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
* @since 1.0.0
@@ -176,10 +179,10 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
this.observationRegistry = observationRegistry;
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
Usage usage) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
return ChatResponseMetadata.builder()
.withId(id)
.withUsage(usage)
@@ -189,12 +192,40 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
.build();
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
return from(chatCompletions, promptFilterMetadata, usage);
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
CompletionsUsage usage) {
return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage));
}

public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) {
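// Copy an existing response's metadata, swapping in the given usage and
// preserving the system-fingerprint entry when present.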
Assert.notNull(chatResponse, "ChatResponse must not be null");
ChatResponseMetadata chatResponseMetadata = chatResponse.getMetadata();
ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder();
builder.withId(chatResponseMetadata.getId())
.withUsage(usage)
.withModel(chatResponseMetadata.getModel())
.withPromptMetadata(chatResponseMetadata.getPromptMetadata());
if (chatResponseMetadata.containsKey("system-fingerprint")) {
builder.withKeyValue("system-fingerprint", chatResponseMetadata.get("system-fingerprint"));
}
return builder.build();
}

public AzureOpenAiChatOptions getDefaultOptions() {
return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
}

@Override
public ChatResponse call(Prompt prompt) {
return this.internalCall(prompt, null);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
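// previousChatResponse carries the usage accumulated across earlier
// tool-call round trips; it is null on the first request.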

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
@@ -210,7 +241,7 @@ public ChatResponse call(Prompt prompt) {
ChatCompletionsOptionsAccessHelper.setStream(options, false);

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse chatResponse = toChatResponse(chatCompletions);
ChatResponse chatResponse = toChatResponse(chatCompletions, previousChatResponse);
observationContext.setResponse(chatResponse);
return chatResponse;
});
@@ -220,14 +251,18 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call internalCall with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(prompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
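// As in internalCall, previousChatResponse threads the usage accumulated
// so far through recursive tool-call invocations of the streaming path.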

return Flux.deferContextual(contextView -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
@@ -279,16 +314,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
})
.flatMap(mono -> mono);

return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {

ChatResponse chatResponse = toChatResponse(chatCompletions);
final Flux<ChatResponse> chatResponseFlux = accessibleChatCompletionsFlux.map(chatCompletion -> {
if (previousChatResponse == null) {
return toChatResponse(chatCompletion);
}
// Accumulate the usage from the previous chat response
CompletionsUsage usage = chatCompletion.getUsage();
Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
return toChatResponse(chatCompletion, accumulatedUsage);
}).buffer(2, 1).map(bufferList -> {
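// buffer(2, 1) emits sliding pairs [n, n+1]; when stream options request
// usage, the service sends a final usage-only chunk, and the pairing lets
// the last content chunk absorb that trailing chunk's token statistics.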
ChatResponse chatResponse1 = bufferList.get(0);
if (options.getStreamOptions() != null && options.getStreamOptions().isIncludeUsage()) {
if (bufferList.size() == 2) {
ChatResponse chatResponse2 = bufferList.get(1);
if (chatResponse2 != null && chatResponse2.getMetadata() != null
&& !UsageUtils.isEmpty(chatResponse2.getMetadata().getUsage())) {
return toChatResponse(chatResponse1, chatResponse2.getMetadata().getUsage());
}
}
}
return chatResponse1;
});

return chatResponseFlux.flatMap(chatResponse -> {
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call internalStream with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
}

Flux<ChatResponse> flux = Flux.just(chatResponse)
@@ -305,6 +360,44 @@ public Flux<ChatResponse> stream(Prompt prompt) {

private ChatResponse toChatResponse(ChatCompletions chatCompletions) {

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
"choiceIndex", choice.getIndex(),
"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
}

private ChatResponse toChatResponse(ChatCompletions chatCompletions, Usage usage) {
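// Overload used by the streaming path when an already-accumulated usage
// should be attached instead of the usage reported on this chunk.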

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
"choiceIndex", choice.getIndex(),
"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, usage));
}

private ChatResponse toChatResponse(ChatResponse chatResponse, Usage usage) {
return new ChatResponse(chatResponse.getResults(), from(chatResponse, usage));
}

private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatResponse previousChatResponse) {

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
@@ -316,8 +409,12 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
}).toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
Usage currentUsage = null;
if (chatCompletions.getUsage() != null) {
currentUsage = AzureOpenAiUsage.from(chatCompletions);
}
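// Fold this response's usage into the running total carried over from the
// previous tool-call round trip.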
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));
}
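
This accumulation is the core of the change: each recursive tool-call round trip folds its predecessor's token counts into the next response, so the response the caller finally receives reports the cost of the whole exchange. A minimal standalone sketch of the same idea, with a hypothetical TokenUsage record standing in for Spring AI's Usage and UsageUtils types:

import java.util.List;

class CumulativeUsageSketch {

	// Hypothetical stand-in for org.springframework.ai.chat.metadata.Usage.
	record TokenUsage(long promptTokens, long completionTokens) {
		long totalTokens() {
			return promptTokens + completionTokens;
		}
	}

	// Mirrors the idea behind UsageUtils.getCumulativeUsage: sum element-wise
	// with the usage gathered so far; a null "previous" means the first call.
	static TokenUsage accumulate(TokenUsage current, TokenUsage previous) {
		if (previous == null) {
			return current;
		}
		return new TokenUsage(current.promptTokens() + previous.promptTokens(),
				current.completionTokens() + previous.completionTokens());
	}

	public static void main(String[] args) {
		// Two round trips: the request that triggers the tool call, then the
		// follow-up request carrying the tool result.
		List<TokenUsage> roundTrips = List.of(new TokenUsage(320, 48), new TokenUsage(410, 95));
		TokenUsage total = null;
		for (TokenUsage usage : roundTrips) {
			total = accumulate(usage, total);
		}
		System.out.println(total.totalTokens()); // 873 = 320 + 48 + 410 + 95
	}
}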

private Generation buildGeneration(ChatChoice choice, Map<String, Object> metadata) {
AzureOpenAiChatModelFunctionCallIT.java
@@ -24,6 +24,7 @@
import java.util.stream.Collectors;

import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
import com.azure.core.credential.AzureKeyCredential;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
@@ -80,7 +81,12 @@ void functionCallTest() {

logger.info("Response: {}", response);

assertThat(response.getResult()).isNotNull();
assertThat(response.getResult().getOutput()).isNotNull();
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
assertThat(response.getMetadata()).isNotNull();
assertThat(response.getMetadata().getUsage()).isNotNull();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800);
}

@Test
@@ -142,6 +148,34 @@ void streamFunctionCallTest() {

}

@Test
void streamFunctionCallUsageTest() {
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
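// Ask the service to append a final usage-only chunk to the stream.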
streamOptions.setIncludeUsage(true);

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(this.selectedModel)
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the current weather in a given location")
.inputType(MockWeatherService.Request.class)
.build()))
.withStreamOptions(streamOptions)
.build();

Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));

ChatResponse chatResponse = response.last().block();
logger.info("Response: {}", chatResponse);

assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800);

}
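
The buffer-and-peek trick this test exercises is easiest to see in isolation. A small Reactor sketch with toy string elements (the operator and its semantics are real; the chunk names are made up):

import reactor.core.publisher.Flux;

class BufferPairingSketch {

	public static void main(String[] args) {
		// With includeUsage enabled the service emits a final usage-only chunk.
		// buffer(2, 1) turns [c1, c2, usage] into [c1, c2], [c2, usage], [usage],
		// so each element can peek at its successor; that is how the last content
		// chunk absorbs the trailing usage statistics.
		Flux.just("c1", "c2", "usage")
			.buffer(2, 1)
			.subscribe(System.out::println);
		// Prints: [c1, c2] then [c2, usage] then [usage]
	}
}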

@Test
void functionCallSequentialAndStreamTest() {
