@@ -32,18 +32,18 @@
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import reactor.core.publisher.Flux;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiGenerationMetadata;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.metadata.ChoiceMetadata;
import org.springframework.ai.metadata.PromptMetadata;
import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

/**
@@ -134,7 +134,7 @@ public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) {
}

@Override
public String generate(String text) {
public String call(String text) {

ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text);

@@ -160,7 +160,7 @@ public String generate(String text) {
}

@Override
public ChatResponse generate(Prompt prompt) {
public ChatResponse call(Prompt prompt) {

ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(false);
@@ -174,11 +174,12 @@ public ChatResponse generate(Prompt prompt) {
List<Generation> generations = chatCompletions.getChoices()
.stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withChoiceMetadata(generateChoiceMetadata(choice)))
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();

return new ChatResponse(generations, AzureOpenAiGenerationMetadata.from(chatCompletions))
.withPromptMetadata(generatePromptMetadata(chatCompletions));
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
return new ChatResponse(generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
}

@Override
@@ -199,14 +200,17 @@ public Flux<ChatResponse> generateStream(Prompt prompt) {
.flatMap(List::stream)
.map(choice -> {
var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null;
var generation = new Generation(content).withChoiceMetadata(generateChoiceMetadata(choice));
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
}));
}

private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

List<ChatRequestMessage> azureMessages = prompt.getMessages().stream().map(this::fromSpringAiMessage).toList();
List<ChatRequestMessage> azureMessages = prompt.getInstructions()
.stream()
.map(this::fromSpringAiMessage)
.toList();

ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

@@ -233,8 +237,8 @@ private ChatRequestMessage fromSpringAiMessage(Message message) {

}

private ChoiceMetadata generateChoiceMetadata(ChatChoice choice) {
return ChoiceMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults());
private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) {
return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults());
}

private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
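For context on the renames in the hunks above, here is a minimal caller-side sketch. It assumes only what is visible in this diff (call(String), call(Prompt), getResult(), getOutput().getContent(), and the new chat.prompt/chat.messages packages); the class name ApiRenameSketch and the prompt text are illustrative, not part of the change.

import java.util.List;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;

class ApiRenameSketch {

    // generate(String) is now call(String) and still returns the plain text answer.
    String simpleCall(ChatClient chatClient) {
        return chatClient.call("Generate the names of 5 famous pirates.");
    }

    // generate(Prompt) is now call(Prompt); the first generation is reached via
    // getResult() and its text via getOutput().getContent().
    String promptCall(ChatClient chatClient) {
        Prompt prompt = new Prompt(List.of(new UserMessage("Generate the names of 5 famous pirates.")));
        ChatResponse response = chatClient.call(prompt);
        return response.getResult().getOutput().getContent();
    }
}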
@@ -18,38 +18,44 @@

import com.azure.ai.openai.models.ChatCompletions;

import org.springframework.ai.metadata.GenerationMetadata;
import org.springframework.ai.metadata.Usage;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

/**
* {@link GenerationMetadata} implementation for
* {@link ChatResponseMetadata} implementation for
* {@literal Microsoft Azure OpenAI Service}.
*
* @author John Blum
* @see org.springframework.ai.metadata.GenerationMetadata
* @see ChatResponseMetadata
* @since 0.7.1
*/
public class AzureOpenAiGenerationMetadata implements GenerationMetadata {
public class AzureOpenAiChatResponseMetadata implements ChatResponseMetadata {

protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";

@SuppressWarnings("all")
public static AzureOpenAiGenerationMetadata from(ChatCompletions chatCompletions) {
public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions,
PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
AzureOpenAiGenerationMetadata generationMetadata = new AzureOpenAiGenerationMetadata(id, usage);
return generationMetadata;
AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage,
promptFilterMetadata);
return chatResponseMetadata;
}

private final String id;

private final Usage usage;

protected AzureOpenAiGenerationMetadata(String id, AzureOpenAiUsage usage) {
private final PromptMetadata promptMetadata;

protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) {
this.id = id;
this.usage = usage;
this.promptMetadata = promptMetadata;
}

public String getId() {
@@ -61,6 +67,11 @@ public Usage getUsage() {
return this.usage;
}

@Override
public PromptMetadata getPromptMetadata() {
return this.promptMetadata;
}

@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
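A minimal sketch of how callers read the consolidated metadata after this change, assuming only the accessors shown in the diff (ChatResponse.getMetadata(), plus getPromptMetadata(), getUsage() and getRateLimit() on the response metadata); the class name MetadataSketch is illustrative.

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;

class MetadataSketch {

    // Prompt filter metadata no longer hangs off ChatResponse directly; it now travels
    // inside the response metadata, next to usage and rate limit information.
    void inspect(ChatResponse response) {
        ChatResponseMetadata metadata = response.getMetadata();
        PromptMetadata promptMetadata = metadata.getPromptMetadata();
        Usage usage = metadata.getUsage();
        System.out.println(promptMetadata + " | " + usage + " | " + metadata.getRateLimit());
    }
}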
@@ -19,7 +19,7 @@
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.CompletionsUsage;

import org.springframework.ai.metadata.Usage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

/**
@@ -14,14 +14,15 @@

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -53,8 +54,8 @@ void roleTest() {
UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates.");

Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.generate(prompt);
assertThat(response.getGeneration().getContent()).contains("Blackbeard");
ChatResponse response = chatClient.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}

@Test
@@ -70,9 +71,9 @@ void outputParser() {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.generate(prompt).getGeneration();
Generation generation = chatClient.call(prompt).getResult();

List<String> list = outputParser.parse(generation.getContent());
List<String> list = outputParser.parse(generation.getOutput().getContent());
assertThat(list).hasSize(5);

}
@@ -89,9 +90,9 @@ void mapOutputParser() {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.generate(prompt).getGeneration();
Generation generation = chatClient.call(prompt).getResult();

Map<String, Object> result = outputParser.parse(generation.getContent());
Map<String, Object> result = outputParser.parse(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));

}
@@ -108,9 +109,9 @@ void beanOutputParser() {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.generate(prompt).getGeneration();
Generation generation = chatClient.call(prompt).getResult();

ActorsFilms actorsFilms = outputParser.parse(generation.getContent());
ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isNotNull();
}

@@ -129,9 +130,9 @@ void beanOutputParserRecords() {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.generate(prompt).getGeneration();
Generation generation = chatClient.call(prompt).getResult();

ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent());
ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent());
System.out.println(actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
@@ -154,9 +155,10 @@ void beanStreamOutputParserRecords() {
.collectList()
.block()
.stream()
.map(ChatResponse::getGenerations)
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getContent)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.filter(Objects::nonNull)
.collect(Collectors.joining());

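A minimal sketch of collecting a streamed Flux<ChatResponse> into a single string with the renamed accessors (getResults(), getOutput(), AssistantMessage.getContent()), mirroring the test above; how the Flux is obtained is deliberately left out, since the streaming entry point is not shown in these hunks.

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;

class StreamingAggregationSketch {

    // Blocks on the stream, then joins the non-null chunk contents in arrival order.
    String aggregate(Flux<ChatResponse> responses) {
        return responses.collectList()
            .block()
            .stream()
            .map(ChatResponse::getResults)
            .flatMap(List::stream)
            .map(Generation::getOutput)
            .map(AssistantMessage::getContent)
            .filter(Objects::nonNull)
            .collect(Collectors.joining());
    }
}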
@@ -28,12 +28,13 @@
import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.metadata.ChoiceMetadata;
import org.springframework.ai.metadata.GenerationMetadata;
import org.springframework.ai.metadata.PromptMetadata;
import org.springframework.ai.metadata.RateLimit;
import org.springframework.ai.metadata.Usage;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -75,14 +76,15 @@ void azureOpenAiMetadataCapturedDuringGeneration() {

Prompt prompt = new Prompt("Can I fly like a bird?");

ChatResponse response = this.aiClient.generate(prompt);
ChatResponse response = this.aiClient.call(prompt);

assertThat(response).isNotNull();

Generation generation = response.getGeneration();
Generation generation = response.getResult();

assertThat(generation).isNotNull()
.extracting(Generation::getContent)
.extracting(Generation::getOutput)
.extracting(AssistantMessage::getContent)
.isEqualTo("No! You will actually land with a resounding thud. This is the way!");

assertPromptMetadata(response);
@@ -92,7 +94,7 @@ void azureOpenAiMetadataCapturedDuringGeneration() {

private void assertPromptMetadata(ChatResponse response) {

PromptMetadata promptMetadata = response.getPromptMetadata();
PromptMetadata promptMetadata = response.getMetadata().getPromptMetadata();

assertThat(promptMetadata).isNotNull();

@@ -106,12 +108,12 @@ private void assertPromptMetadata(ChatResponse response) {

private void assertGenerationMetadata(ChatResponse response) {

GenerationMetadata generationMetadata = response.getGenerationMetadata();
ChatResponseMetadata chatResponseMetadata = response.getMetadata();

assertThat(generationMetadata).isNotNull();
assertThat(generationMetadata.getRateLimit()).isEqualTo(RateLimit.NULL);
assertThat(chatResponseMetadata).isNotNull();
assertThat(chatResponseMetadata.getRateLimit()).isEqualTo(RateLimit.NULL);

Usage usage = generationMetadata.getUsage();
Usage usage = chatResponseMetadata.getUsage();

assertThat(usage).isNotNull();
assertThat(usage).isNotEqualTo(Usage.NULL);
@@ -122,11 +124,11 @@

private void assertChoiceMetadata(Generation generation) {

ChoiceMetadata choiceMetadata = generation.getChoiceMetadata();
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();

assertThat(choiceMetadata).isNotNull();
assertThat(choiceMetadata.getFinishReason()).isEqualTo("stop");
assertContentFilterResults(choiceMetadata.getContentFilterMetadata());
assertThat(chatGenerationMetadata).isNotNull();
assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop");
assertContentFilterResults(chatGenerationMetadata.getContentFilterMetadata());
}

private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt,
@@ -17,7 +17,7 @@
package org.springframework.ai.bedrock;

import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics;
import org.springframework.ai.metadata.Usage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

/**
@@ -19,8 +19,8 @@
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.MessageType;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;

/**
* Converts a list of messages to a prompt for bedrock models.