diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index 72253696b96..5decb221b0b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -32,18 +32,18 @@ import com.azure.core.util.IterableStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; -import org.springframework.ai.azure.openai.metadata.AzureOpenAiGenerationMetadata; +import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; import org.springframework.util.Assert; /** @@ -134,7 +134,7 @@ public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) { } @Override - public String generate(String text) { + public String call(String text) { ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text); @@ -160,7 +160,7 @@ public String generate(String text) { } @Override - public 
ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(false); @@ -174,11 +174,12 @@ public ChatResponse generate(Prompt prompt) { List generations = chatCompletions.getChoices() .stream() .map(choice -> new Generation(choice.getMessage().getContent()) - .withChoiceMetadata(generateChoiceMetadata(choice))) + .withGenerationMetadata(generateChoiceMetadata(choice))) .toList(); - return new ChatResponse(generations, AzureOpenAiGenerationMetadata.from(chatCompletions)) - .withPromptMetadata(generatePromptMetadata(chatCompletions)); + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + return new ChatResponse(generations, + AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); } @Override @@ -199,14 +200,17 @@ public Flux generateStream(Prompt prompt) { .flatMap(List::stream) .map(choice -> { var content = (choice.getDelta() != null) ? 
choice.getDelta().getContent() : null; - var generation = new Generation(content).withChoiceMetadata(generateChoiceMetadata(choice)); + var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); return new ChatResponse(List.of(generation)); })); } private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { - List azureMessages = prompt.getMessages().stream().map(this::fromSpringAiMessage).toList(); + List azureMessages = prompt.getInstructions() + .stream() + .map(this::fromSpringAiMessage) + .toList(); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -233,8 +237,8 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { } - private ChoiceMetadata generateChoiceMetadata(ChatChoice choice) { - return ChoiceMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults()); + private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { + return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults()); } private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java similarity index 62% rename from models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java rename to models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java index 51208526da8..4472a6df2c3 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java +++ 
b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java @@ -18,38 +18,44 @@ import com.azure.ai.openai.models.ChatCompletions; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** - * {@link GenerationMetadata} implementation for + * {@link ChatResponseMetadata} implementation for * {@literal Microsoft Azure OpenAI Service}. * * @author John Blum - * @see org.springframework.ai.metadata.GenerationMetadata + * @see ChatResponseMetadata * @since 0.7.1 */ -public class AzureOpenAiGenerationMetadata implements GenerationMetadata { +public class AzureOpenAiChatResponseMetadata implements ChatResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; @SuppressWarnings("all") - public static AzureOpenAiGenerationMetadata from(ChatCompletions chatCompletions) { + public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions, + PromptMetadata promptFilterMetadata) { Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); String id = chatCompletions.getId(); AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); - AzureOpenAiGenerationMetadata generationMetadata = new AzureOpenAiGenerationMetadata(id, usage); - return generationMetadata; + AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage, + promptFilterMetadata); + return chatResponseMetadata; } private final String id; private final Usage usage; - protected AzureOpenAiGenerationMetadata(String id, AzureOpenAiUsage usage) { + private final PromptMetadata promptMetadata; + + protected 
AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) { this.id = id; this.usage = usage; + this.promptMetadata = promptMetadata; } public String getId() { @@ -61,6 +67,11 @@ public Usage getUsage() { return this.usage; } + @Override + public PromptMetadata getPromptMetadata() { + return this.promptMetadata; + } + @Override public String toString() { return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java index 73bdc47722a..b5af6fe603b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java @@ -19,7 +19,7 @@ import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.CompletionsUsage; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 141fc0fd8bc..d82db85c31e 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -14,14 +14,15 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import 
org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -53,8 +54,8 @@ void roleTest() { UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates."); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatClient.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -89,9 +90,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", 
"an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -108,9 +109,9 @@ void beanOutputParser() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - ActorsFilms actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isNotNull(); } @@ -129,9 +130,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); System.out.println(actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); @@ -154,9 +155,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .filter(Objects::nonNull) 
.collect(Collectors.joining()); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java index 1a69a6bed6b..8974540dcd9 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java @@ -28,12 +28,13 @@ import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -75,14 +76,15 @@ void azureOpenAiMetadataCapturedDuringGeneration() { Prompt prompt = new Prompt("Can I fly like a bird?"); - ChatResponse response = this.aiClient.generate(prompt); + ChatResponse response = this.aiClient.call(prompt); assertThat(response).isNotNull(); - 
Generation generation = response.getGeneration(); + Generation generation = response.getResult(); assertThat(generation).isNotNull() - .extracting(Generation::getContent) + .extracting(Generation::getOutput) + .extracting(AssistantMessage::getContent) .isEqualTo("No! You will actually land with a resounding thud. This is the way!"); assertPromptMetadata(response); @@ -92,7 +94,7 @@ void azureOpenAiMetadataCapturedDuringGeneration() { private void assertPromptMetadata(ChatResponse response) { - PromptMetadata promptMetadata = response.getPromptMetadata(); + PromptMetadata promptMetadata = response.getMetadata().getPromptMetadata(); assertThat(promptMetadata).isNotNull(); @@ -106,12 +108,12 @@ private void assertPromptMetadata(ChatResponse response) { private void assertGenerationMetadata(ChatResponse response) { - GenerationMetadata generationMetadata = response.getGenerationMetadata(); + ChatResponseMetadata chatResponseMetadata = response.getMetadata(); - assertThat(generationMetadata).isNotNull(); - assertThat(generationMetadata.getRateLimit()).isEqualTo(RateLimit.NULL); + assertThat(chatResponseMetadata).isNotNull(); + assertThat(chatResponseMetadata.getRateLimit()).isEqualTo(RateLimit.NULL); - Usage usage = generationMetadata.getUsage(); + Usage usage = chatResponseMetadata.getUsage(); assertThat(usage).isNotNull(); assertThat(usage).isNotEqualTo(Usage.NULL); @@ -122,11 +124,11 @@ private void assertGenerationMetadata(ChatResponse response) { private void assertChoiceMetadata(Generation generation) { - ChoiceMetadata choiceMetadata = generation.getChoiceMetadata(); + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); - assertThat(choiceMetadata).isNotNull(); - assertThat(choiceMetadata.getFinishReason()).isEqualTo("stop"); - assertContentFilterResults(choiceMetadata.getContentFilterMetadata()); + assertThat(chatGenerationMetadata).isNotNull(); + assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop"); + 
assertContentFilterResults(chatGenerationMetadata.getContentFilterMetadata()); } private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt, diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 88427d9e025..9fa68242a01 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -17,7 +17,7 @@ package org.springframework.ai.bedrock; import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java index 937819a022a..d01ecffddc5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java @@ -19,8 +19,8 @@ import java.util.List; import java.util.stream.Collectors; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; /** * Converts a list of messages to a prompt for bedrock models. 
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java index 6c7dd7f39c1..a8b67702648 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; @@ -28,12 +29,11 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; /** * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat - * model. + * model. 
* * @author Christian Tzolov * @since 0.8.0 @@ -89,8 +89,8 @@ public BedrockAnthropicChatClient withAnthropicVersion(String anthropicVersion) } @Override - public ChatResponse generate(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + public ChatResponse call(Prompt prompt) { + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); AnthropicChatRequest request = AnthropicChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -109,7 +109,7 @@ public ChatResponse generate(Prompt prompt) { @Override public Flux generateStream(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); AnthropicChatRequest request = AnthropicChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -126,8 +126,8 @@ public Flux generateStream(Prompt prompt) { String stopReason = response.stopReason() != null ? 
response.stopReason() : null; var generation = new Generation(response.completion()); if (response.amazonBedrockInvocationMetrics() != null) { - generation = generation - .withChoiceMetadata(ChoiceMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); } return new ChatResponse(List.of(generation)); }); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java index 50012480737..a5c1e3e188c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java @@ -19,6 +19,7 @@ import java.util.List; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.BedrockUsage; @@ -32,9 +33,8 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * @author Christian Tzolov @@ -112,7 +112,7 @@ public BedrockCohereChatClient withTruncate(Truncate truncate) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); List generations = response.generations().stream().map(g -> { return new 
Generation(g.text()); @@ -127,15 +127,15 @@ public Flux generateStream(Prompt prompt) { if (g.isFinished()) { String finishReason = g.finishReason().name(); Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); - return new ChatResponse( - List.of(new Generation("").withChoiceMetadata(ChoiceMetadata.from(finishReason, usage)))); + return new ChatResponse(List + .of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); } return new ChatResponse(List.of(new Generation(g.text()))); }); } private CohereChatRequest createRequest(Prompt prompt, boolean stream) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); return CohereChatRequest.builder(promptValue) .withTemperature(this.temperature) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java index e4fc8d69dbd..38e157394af 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java @@ -28,13 +28,13 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat - * model. + * model. 
* * @author Christian Tzolov * @since 0.8.0 @@ -69,8 +69,8 @@ public BedrockLlama2ChatClient withMaxGenLen(Integer maxGenLen) { } @Override - public ChatResponse generate(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + public ChatResponse call(Prompt prompt) { + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); var request = Llama2ChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -80,14 +80,14 @@ public ChatResponse generate(Prompt prompt) { Llama2ChatResponse response = this.chatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.generation()) - .withChoiceMetadata(ChoiceMetadata.from(response.stopReason().name(), extractUsage(response))))); + return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( + ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); } @Override public Flux generateStream(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); var request = Llama2ChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -100,7 +100,7 @@ public Flux generateStream(Prompt prompt) { return fluxResponse.map(response -> { String stopReason = response.stopReason() != null ? 
response.stopReason().name() : null; return new ChatResponse(List.of(new Generation(response.generation()) - .withChoiceMetadata(ChoiceMetadata.from(stopReason, extractUsage(response))))); + .withGenerationMetadata(ChatGenerationMetadata.from(stopReason, extractUsage(response))))); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java index 14a95d04aba..b2279ce836b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java @@ -19,6 +19,7 @@ import java.util.List; import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; @@ -29,9 +30,8 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * @author Christian Tzolov @@ -74,7 +74,7 @@ public BedrockTitanChatClient withStopSequences(List stopSequences) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); List generations = response.results().stream().map(result -> { return new Generation(result.outputText()); @@ -91,12 +91,13 @@ public Flux generateStream(Prompt prompt) { if (chunk.amazonBedrockInvocationMetrics() != null) { String 
completionReason = chunk.completionReason().name(); - generation = generation - .withChoiceMetadata(ChoiceMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); } else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { String completionReason = chunk.completionReason().name(); - generation = generation.withChoiceMetadata(ChoiceMetadata.from(completionReason, extractUsage(chunk))); + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from(completionReason, extractUsage(chunk))); } return new ChatResponse(List.of(generation)); @@ -104,7 +105,7 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount( } private TitanChatRequest createRequest(Prompt prompt, boolean stream) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); return TitanChatRequest.builder(promptValue) .withTemperature(this.temperature) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java index 14ecad3f9e4..fd342892766 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import 
software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -17,11 +18,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -52,9 +53,9 @@ void roleTest() { Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); + ChatResponse response = client.call(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -88,9 +89,9 @@ void mapOutputParser() { PromptTemplate 
promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -112,9 +113,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -137,9 +138,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java index bb4ffc32fe5..9fc97ee2c6f 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java +++ 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -18,11 +19,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -53,8 +54,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { 
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -89,9 +90,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -112,9 +113,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -137,9 +138,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) 
.collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java index 4cbbc8d072e..b1ec0458e17 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -19,11 +20,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -54,9 +55,9 @@ void roleTest() { Prompt prompt = new Prompt(List.of(userMessage, 
systemMessage)); - ChatResponse response = client.generate(prompt); + ChatResponse response = client.call(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -73,9 +74,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -91,9 +92,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -116,9 +117,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); 
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -142,9 +143,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java index 2f96bb84f23..1c07073cc0e 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -19,11 +20,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import 
org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -54,8 +55,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -72,9 +73,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -93,9 +94,9 @@ void mapOutputParser() { Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -117,9 +118,9 @@ void 
beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -143,9 +144,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java index 8b32bb80c19..a279f79372b 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java @@ -32,7 +32,7 @@ import org.springframework.ai.huggingface.model.GenerateParameters; import org.springframework.ai.huggingface.model.GenerateRequest; import org.springframework.ai.huggingface.model.GenerateResponse; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; /** * An implementation of {@link ChatClient} that interfaces with HuggingFace Inference @@ -86,7 +86,7 @@ public HuggingfaceChatClient(final String apiToken, String basePath) { * @return ChatResponse containing the generated 
text and other related details. */ @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { GenerateRequest generateRequest = new GenerateRequest(); generateRequest.setInputs(prompt.getContents()); GenerateParameters generateParameters = new GenerateParameters(); diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 7c07a61e340..05f814a5fc8 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -22,7 +22,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.huggingface.HuggingfaceChatClient; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; @@ -47,8 +47,8 @@ void helloWorldCompletion() { [/INST] """; Prompt prompt = new Prompt(mistral7bInstruct); - ChatResponse chatResponse = huggingfaceChatClient.generate(prompt); - assertThat(chatResponse.getGeneration().getContent()).isNotEmpty(); + ChatResponse chatResponse = huggingfaceChatClient.call(prompt); + assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); String expectedResponse = """ ```json { @@ -57,9 +57,9 @@ void helloWorldCompletion() { "address": "#1 Samuel St." 
} ```"""; - assertThat(chatResponse.getGeneration().getContent()).isEqualTo(expectedResponse); - assertThat(chatResponse.getGeneration().getProperties()).containsKey("generated_tokens"); - assertThat(chatResponse.getGeneration().getProperties()).containsEntry("generated_tokens", 39); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo(expectedResponse); + assertThat(chatResponse.getResult().getOutput().getProperties()).containsKey("generated_tokens"); + assertThat(chatResponse.getResult().getOutput().getProperties()).containsEntry("generated_tokens", 39); } diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 0da9e9a5353..ad143f9afb0 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -21,6 +21,12 @@ + + + org.springframework.boot + spring-boot + + org.springframework.ai spring-ai-core diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java index 3c72a5d42a1..e1835c67ac1 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java @@ -16,25 +16,30 @@ package org.springframework.ai.ollama; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import 
org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; /** * {@link ChatClient} implementation for {@literal Ollma}. @@ -58,6 +63,8 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient { private Map clientOptions; + private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + public OllamaChatClient(OllamaApi chatApi) { this.chatApi = chatApi; } @@ -78,12 +85,13 @@ public OllamaChatClient withOptions(OllamaOptions options) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { OllamaApi.ChatResponse response = this.chatApi.chat(request(prompt, this.model, false)); var generator = new Generation(response.message().content()); if (response.promptEvalCount() != null && response.evalCount() != null) { - generator = generator.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(response))); + generator = generator + .withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(response))); } return new ChatResponse(List.of(generator)); } @@ -97,7 +105,8 @@ public Flux generateStream(Prompt prompt) { Generation generation = (chunk.message() != null) ? 
new Generation(chunk.message().content()) : new Generation(""); if (Boolean.TRUE.equals(chunk.done())) { - generation = generation.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(chunk))); + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(chunk))); } return new ChatResponse(List.of(generation)); }); @@ -120,20 +129,57 @@ public Long getGenerationTokens() { private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean stream) { - List ollamaMessages = prompt.getMessages() + List ollamaMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.ASSISTANT) .map(m -> OllamaApi.Message.builder(toRole(m)).withContent(m.getContent()).build()) .toList(); + // runtime options + Map promptOptions = objectToMap(prompt.getOptions()); + Map clientOptionsToUse = merge(promptOptions, this.clientOptions, HashMap.class); + return ChatRequest.builder(model) .withStream(stream) .withMessages(ollamaMessages) - .withOptions(this.clientOptions) + .withOptions(clientOptionsToUse) .build(); } + public static Map objectToMap(Object source) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static T mapToClass(Map source, Class clazz) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, clazz); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static T merge(Object source, Object target, Class clazz) { + Map sourceMap = objectToMap(source); + Map targetMap = objectToMap(target); + + targetMap.putAll(sourceMap.entrySet() + .stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + + return 
mapToClass(targetMap, clazz); + } + private OllamaApi.Message.Role toRole(Message message) { switch (message.getMessageType()) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 172b08fe52b..5bc3f6c6f3c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.ChatOptions; /** * Helper class for creating strongly-typed Ollama options. @@ -38,7 +39,7 @@ * Types */ @JsonInclude(Include.NON_NULL) -public class OllamaOptions { +public class OllamaOptions implements ChatOptions { // @formatter:off /** diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java index 88a184e6f7e..84617fea0bc 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java @@ -11,6 +11,8 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.ChatOptionsBuilder; +import org.springframework.ai.chat.messages.AssistantMessage; import org.testcontainers.containers.GenericContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -22,11 +24,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; 
import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -51,7 +53,7 @@ class OllamaChatClientIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL + " ' generative ... 
would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL); logger.info(MODEL + " pulling competed!"); @@ -72,9 +74,16 @@ void roleTest() { UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy."); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + // portable/generic options + var chatOptionsBuilder = ChatOptionsBuilder.builder(); + + // ollama specific options + var ollamaOptions = new OllamaOptions().withLowVRAM(true); + + Prompt prompt = new Prompt(List.of(userMessage, systemMessage), + chatOptionsBuilder.withTemperature(0.7f).build()); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -91,9 +100,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -112,9 +121,9 @@ void mapOutputParser() { Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); 
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -136,9 +145,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -162,9 +171,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java index 393366092b1..8f066d05b95 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java @@ -36,7 +36,7 @@ class OllamaEmbeddingClientIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ..."); + logger.info("Start pulling the 'orca-mini' generative (3GB) ... 
would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); logger.info("orca-mini pulling competed!"); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index ae5ed2ea1c0..da34fa83997 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -57,7 +57,7 @@ public class OllamaApiIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ..."); + logger.info("Start pulling the 'orca-mini' generative (3GB) ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); logger.info("orca-mini pulling competed!"); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java similarity index 97% rename from models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java rename to models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index 60c4f098f98..6320854ec35 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -25,7 +25,7 @@ /** * @author Christian Tzolov */ -public class OllamaOptionsTests { +public class OllamaModelOptionsTests { @Test public void testOptions() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 9ea1617411e..64ebc1c45d0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -28,17 +28,17 @@ import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiClientErrorException; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; -import org.springframework.ai.openai.metadata.OpenAiGenerationMetadata; +import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -94,10 +94,10 @@ public void setTemperature(Double temperature) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getMessages(); + List messages = prompt.getInstructions(); List chatCompletionMessages = messages.stream() .map(m -> new 
ChatCompletionMessage(m.getContent(), @@ -118,18 +118,18 @@ public ChatResponse generate(Prompt prompt) { List generations = chatCompletion.choices().stream().map(choice -> { return new Generation(choice.message().content(), Map.of("role", choice.message().role().name())) - .withChoiceMetadata(ChoiceMetadata.from(choice.finishReason().name(), null)); + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); }).toList(); return new ChatResponse(generations, - OpenAiGenerationMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits)); + OpenAiChatResponseMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits)); }); } @Override public Flux generateStream(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getMessages(); + List messages = prompt.getInstructions(); List chatCompletionMessages = messages.stream() .map(m -> new ChatCompletionMessage(m.getContent(), @@ -153,7 +153,7 @@ public Flux generateStream(Prompt prompt) { var generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId))); if (choice.finishReason() != null) { generation = generation - .withChoiceMetadata(ChoiceMetadata.from(choice.finishReason().name(), null)); + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); } return generation; }).toList(); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 5524472b1a7..627c476902f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.springframework.ai.openai; import java.time.Duration; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java new file mode 100644 index 00000000000..87b6a2fcde3 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -0,0 +1,148 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.image.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.openai.api.*; +import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; +import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.util.List; + +public class OpenAiImageClient implements ImageClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private OpenAiImageOptions options; + + private final OpenAiImageApi openAiImageApi; + + public final RetryTemplate retryTemplate = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(OpenAiApi.OpenAiApiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .build(); + + public OpenAiImageClient(OpenAiImageApi openAiImageApi) { + Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null"); + this.openAiImageApi = openAiImageApi; + } + + public OpenAiImageOptions getOptions() { + return options; + } + + @Override + public ImageResponse call(ImagePrompt imagePrompt) { + return this.retryTemplate.execute(ctx -> { + ImageOptions runtimeOptions = imagePrompt.getOptions(); + OpenAiImageOptions imageOptionsToUse = updateImageOptions(imagePrompt.getOptions()); + + // Merge the runtime options passed via the prompt with the + // StabilityAiImageClient + // options configured via Autoconfiguration. 
+ // Runtime options overwrite StabilityAiImageClient options + OpenAiImageOptions optionsToUse = ModelOptionsUtils.merge(runtimeOptions, this.options, + OpenAiImageOptionsImpl.class); + + // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions + // data + // types to the data types used in OpenAiImageApi + String instructions = imagePrompt.getInstructions().get(0).getText(); + String size; + if (imageOptionsToUse.getWidth() != null && imageOptionsToUse.getHeight() != null) { + size = imageOptionsToUse.getWidth() + "x" + imageOptionsToUse.getHeight(); + } + else { + size = null; + } + OpenAiImageApi.OpenAiImageRequest openAiImageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions, + imageOptionsToUse.getModel(), imageOptionsToUse.getN(), imageOptionsToUse.getQuality(), size, + imageOptionsToUse.getResponseFormat(), imageOptionsToUse.getStyle(), imageOptionsToUse.getUser()); + + // Make the request + ResponseEntity imageResponseEntity = this.openAiImageApi + .createImage(openAiImageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(imageResponseEntity, openAiImageRequest); + + }); + } + + private ImageResponse convertResponse(ResponseEntity imageResponseEntity, + OpenAiImageApi.OpenAiImageRequest openAiImageRequest) { + OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody(); + if (imageApiResponse == null) { + logger.warn("No image response returned for request: {}", openAiImageRequest); + return new ImageResponse(List.of()); + } + + List imageGenerationList = imageApiResponse.data().stream().map(entry -> { + return new ImageGeneration(new Image(entry.url(), entry.b64Json()), + new OpenAiImageGenerationMetadata(entry.revisedPrompt())); + }).toList(); + + ImageResponseMetadata openAiImageResponseMetadata = OpenAiImageResponseMetadata.from(imageApiResponse); + return new ImageResponse(imageGenerationList, openAiImageResponseMetadata); + } + + 
private OpenAiImageOptions updateImageOptions(ImageOptions runtimeImageOptions) { + OpenAiImageOptionsBuilder openAiImageOptionsBuilder = OpenAiImageOptionsBuilder.builder(); + if (runtimeImageOptions != null) { + // Handle portable image options + if (runtimeImageOptions.getN() != null) { + openAiImageOptionsBuilder.withN(runtimeImageOptions.getN()); + } + if (runtimeImageOptions.getModel() != null) { + openAiImageOptionsBuilder.withModel(runtimeImageOptions.getModel()); + } + if (runtimeImageOptions.getResponseFormat() != null) { + openAiImageOptionsBuilder.withResponseFormat(runtimeImageOptions.getResponseFormat()); + } + if (runtimeImageOptions.getWidth() != null) { + openAiImageOptionsBuilder.withWidth(runtimeImageOptions.getWidth()); + } + if (runtimeImageOptions.getHeight() != null) { + openAiImageOptionsBuilder.withHeight(runtimeImageOptions.getHeight()); + } + // Handle OpenAI specific image options + if (runtimeImageOptions instanceof OpenAiImageOptions) { + OpenAiImageOptions runtimeOpenAiImageOptions = (OpenAiImageOptions) runtimeImageOptions; + if (runtimeOpenAiImageOptions.getQuality() != null) { + openAiImageOptionsBuilder.withQuality(runtimeOpenAiImageOptions.getQuality()); + } + if (runtimeOpenAiImageOptions.getStyle() != null) { + openAiImageOptionsBuilder.withStyle(runtimeOpenAiImageOptions.getStyle()); + } + if (runtimeOpenAiImageOptions.getUser() != null) { + openAiImageOptionsBuilder.withUser(runtimeOpenAiImageOptions.getUser()); + } + } + } + OpenAiImageOptions updatedOpenAiImageOptions = openAiImageOptionsBuilder.build(); + return updatedOpenAiImageOptions; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index a31c028872a..55cda04321a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -69,7 +69,7 @@ public OpenAiApi(String openAiToken) { } /** - * Create an new chat completion api. + * Create a new chat completion api. * * @param baseUrl api base URL. * @param openAiToken OpenAI apiKey. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java new file mode 100644 index 00000000000..5a2800087b4 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -0,0 +1,226 @@ +package org.springframework.ai.openai.api; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiClientErrorException; +import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; +import org.springframework.ai.openai.api.OpenAiApi.ResponseError; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +public class OpenAiImageApi { + + private static final String DEFAULT_BASE_URL = "https://api.openai.com"; + + public static final String DEFAULT_IMAGE_MODEL = "dall-e-2"; + + // Assuming RestClient and WebClient are properly defined somewhere + private final RestClient restClient; + + private final ObjectMapper objectMapper; + + /** + * Create a 
new OpenAI Image api with base URL set to https://api.openai.com + * @param openAiToken OpenAI apiKey. + */ + public OpenAiImageApi(String openAiToken) { + this(DEFAULT_BASE_URL, openAiToken, RestClient.builder()); + } + + public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { + + this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + Consumer jsonContentHeaders = headers -> { + headers.setBearerAuth(openAiToken); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + var responseErrorHandler = new ResponseErrorHandler() { + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + if (response.getStatusCode().is4xxClientError()) { + throw new OpenAiApiClientErrorException(String.format("%s - %s", + response.getStatusCode().value(), + OpenAiImageApi.this.objectMapper.readValue(response.getBody(), ResponseError.class))); + } + throw new OpenAiApiException(String.format("%s - %s", response.getStatusCode().value(), + OpenAiImageApi.this.objectMapper.readValue(response.getBody(), ResponseError.class))); + } + } + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class OpenAiImageRequest { + + @JsonProperty("prompt") + private String prompt; + + @JsonProperty("model") + private String model = DEFAULT_IMAGE_MODEL; + + @JsonProperty("n") + private Integer n; + + @JsonProperty("quality") + private String quality; + + @JsonProperty("response_format") + private String responseFormat; + + @JsonProperty("size") + private String size; + + @JsonProperty("style") + private String style; + + 
@JsonProperty("user") + private String user; + + public OpenAiImageRequest() { + } + + public OpenAiImageRequest(String prompt, String model, Integer n, String quality, String size, + String responseFormat, String style, String user) { + this.prompt = prompt; + this.model = model; + this.n = n; + this.quality = quality; + this.size = size; + this.responseFormat = responseFormat; + this.style = style; + this.user = user; + } + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public String getQuality() { + return quality; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + public String getSize() { + return size; + } + + public void setSize(String size) { + this.size = size; + } + + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageRequest that)) + return false; + return Objects.equals(prompt, that.prompt) && Objects.equals(model, that.model) && Objects.equals(n, that.n) + && Objects.equals(quality, that.quality) && Objects.equals(size, that.size) + && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(style, that.style) + && Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hash(prompt, model, n, quality, size, responseFormat, style, 
user); + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record OpenAiImageResponse(@JsonProperty("created") Long created, @JsonProperty("data") List data) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { + + } + + public ResponseEntity createImage(OpenAiImageRequest openAiImageRequest) { + Assert.notNull(openAiImageRequest, "Image request cannot be null."); + Assert.hasLength(openAiImageRequest.getPrompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/images/generations") + .body(openAiImageRequest) + .retrieve() + .toEntity(OpenAiImageResponse.class); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java new file mode 100644 index 00000000000..327f5dcbb56 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java @@ -0,0 +1,13 @@ +package org.springframework.ai.openai.api; + +import org.springframework.ai.image.ImageOptions; + +public interface OpenAiImageOptions extends ImageOptions { + + String getQuality(); + + String getStyle(); + + String getUser(); + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java new file mode 100644 index 00000000000..b896a2f9e5e --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java @@ -0,0 +1,59 @@ +package org.springframework.ai.openai.api; + +public class OpenAiImageOptionsBuilder { + + private final OpenAiImageOptionsImpl options; + + private OpenAiImageOptionsBuilder() { + 
this.options = new OpenAiImageOptionsImpl(); + } + + public static OpenAiImageOptionsBuilder builder() { + return new OpenAiImageOptionsBuilder(); + } + + public OpenAiImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public OpenAiImageOptionsBuilder withModel(String model) { + options.setModel(model); + return this; + } + + public OpenAiImageOptionsBuilder withQuality(String quality) { + options.setQuality(quality); + return this; + } + + public OpenAiImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public OpenAiImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public OpenAiImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public OpenAiImageOptionsBuilder withStyle(String style) { + options.setStyle(style); + return this; + } + + public OpenAiImageOptionsBuilder withUser(String user) { + options.setUser(user); + return this; + } + + public OpenAiImageOptions build() { + return options; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java new file mode 100644 index 00000000000..719c0918cb7 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java @@ -0,0 +1,93 @@ +package org.springframework.ai.openai.api; + +public class OpenAiImageOptionsImpl implements OpenAiImageOptions { + + private Integer n; + + private String model; + + private String quality; + + private String responseFormat; + + private Integer width; + + private Integer height; + + private String style; + + private String user; + + @Override + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String 
getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public String getQuality() { + return quality; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Integer getWidth() { + return width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + @Override + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + @Override + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java similarity index 66% rename from models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java rename to models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java index 23c5474b5af..ad797df518f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java @@ -16,31 +16,31 @@ package org.springframework.ai.openai.metadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; +import 
org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * {@link GenerationMetadata} implementation for {@literal OpenAI}. + * {@link ChatResponseMetadata} implementation for {@literal OpenAI}. * * @author John Blum - * @see org.springframework.ai.metadata.GenerationMetadata - * @see org.springframework.ai.metadata.RateLimit - * @see org.springframework.ai.metadata.Usage + * @see ChatResponseMetadata + * @see RateLimit + * @see Usage * @since 0.7.0 */ -public class OpenAiGenerationMetadata implements GenerationMetadata { +public class OpenAiChatResponseMetadata implements ChatResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; - public static OpenAiGenerationMetadata from(OpenAiApi.ChatCompletion result) { + public static OpenAiChatResponseMetadata from(OpenAiApi.ChatCompletion result) { Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); OpenAiUsage usage = OpenAiUsage.from(result.usage()); - OpenAiGenerationMetadata generationMetadata = new OpenAiGenerationMetadata(result.id(), usage); - return generationMetadata; + OpenAiChatResponseMetadata chatResponseMetadata = new OpenAiChatResponseMetadata(result.id(), usage); + return chatResponseMetadata; } private final String id; @@ -50,11 +50,11 @@ public static OpenAiGenerationMetadata from(OpenAiApi.ChatCompletion result) { private final Usage usage; - protected OpenAiGenerationMetadata(String id, OpenAiUsage usage) { + protected OpenAiChatResponseMetadata(String id, OpenAiUsage usage) { this(id, usage, null); } - protected OpenAiGenerationMetadata(String id, OpenAiUsage usage, @Nullable OpenAiRateLimit rateLimit) { + protected OpenAiChatResponseMetadata(String id, 
OpenAiUsage usage, @Nullable OpenAiRateLimit rateLimit) { this.id = id; this.usage = usage; this.rateLimit = rateLimit; @@ -77,7 +77,7 @@ public Usage getUsage() { return usage != null ? usage : Usage.NULL; } - public OpenAiGenerationMetadata withRateLimit(RateLimit rateLimit) { + public OpenAiChatResponseMetadata withRateLimit(RateLimit rateLimit) { this.rateLimit = rateLimit; return this; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java new file mode 100644 index 00000000000..8e9a5d8b0ac --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.image.ImageGenerationMetadata; + +import java.util.Objects; + +public class OpenAiImageGenerationMetadata implements ImageGenerationMetadata { + + private String revisedPrompt; + + public OpenAiImageGenerationMetadata(String revisedPrompt) { + this.revisedPrompt = revisedPrompt; + } + + public String getRevisedPrompt() { + return revisedPrompt; + } + + @Override + public String toString() { + return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageGenerationMetadata that)) + return false; + return Objects.equals(revisedPrompt, that.revisedPrompt); + } + + @Override + public int hashCode() { + return Objects.hash(revisedPrompt); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java new file mode 100644 index 00000000000..b954affa5f6 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.util.Assert; + +import java.util.Objects; + +public class OpenAiImageResponseMetadata implements ImageResponseMetadata { + + private final Long created; + + public static OpenAiImageResponseMetadata from(OpenAiImageApi.OpenAiImageResponse openAiImageResponse) { + Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null"); + return new OpenAiImageResponseMetadata(openAiImageResponse.created()); + } + + protected OpenAiImageResponseMetadata(Long created) { + this.created = created; + } + + @Override + public Long created() { + return this.created; + } + + @Override + public String toString() { + return "OpenAiImageResponseMetadata{" + "created=" + created + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageResponseMetadata that)) + return false; + return Objects.equals(created, that.created); + } + + @Override + public int hashCode() { + return Objects.hash(created); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java index 4a0697cb397..c740c9d33fe 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java @@ -18,7 +18,7 @@ import java.time.Duration; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.RateLimit; /** * {@link RateLimit} implementation for {@literal OpenAI}. 
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 1b08d449dda..09361b2d6f9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -16,7 +16,7 @@ package org.springframework.ai.openai.metadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.util.Assert; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index be6eb4b4bf4..21c28652202 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -26,7 +26,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.metadata.OpenAiRateLimit; import org.springframework.http.ResponseEntity; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 894e4e08002..b14b0ba9391 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -2,6 +2,7 @@ import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -11,9 +12,12 @@ public class OpenAiTestConfiguration { @Bean public OpenAiApi openAiApi() { - String apiKey = getApiKey(); - OpenAiApi openAiService = new OpenAiApi(apiKey); - return openAiService; + return new OpenAiApi(getApiKey()); + } + + @Bean + public OpenAiImageApi openAiImageApi() { + return new OpenAiImageApi(getApiKey()); } private String getApiKey() { @@ -32,6 +36,13 @@ public OpenAiChatClient openAiChatClient(OpenAiApi api) { return openAiChatClient; } + @Bean + public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) { + OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi); + // openAiImageClient.setModel("foobar"); + return openAiImageClient; + } + @Bean public EmbeddingClient openAiEmbeddingClient(OpenAiApi api) { return new OpenAiEmbeddingClient(api); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java index 92fc5bd357e..9bdb44d8ec2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java @@ -15,10 +15,10 @@ import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.testutils.AbstractIT; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import 
org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.SimpleVectorStore; @@ -90,10 +90,10 @@ void acmeChain() { // Create the prompt ad-hoc for now, need to put in system message and user // message via ChatPromptTemplate or some other message building mechanic; - logger.info("Asking AI model to reply to question."); + logger.info("Asking AI model to reply to question."); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); logger.info("AI responded."); - ChatResponse response = chatClient.generate(prompt); + ChatResponse response = chatClient.call(prompt); evaluateQuestionAndAnswer(userQuery, response, true); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java index 9697a8b5f23..aa35ba9e04b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java @@ -10,16 +10,17 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; 
-import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.convert.support.DefaultConversionService; @@ -41,9 +42,9 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = openAiChatClient.generate(prompt); - assertThat(response.getGenerations()).hasSize(1); - assertThat(response.getGenerations().get(0).getContent()).contains("Blackbeard"); + ChatResponse response = openAiChatClient.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... 
evaluateQuestionAndAnswer(request, response, false); } @@ -60,9 +61,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.openAiChatClient.generate(prompt).getGeneration(); + Generation generation = this.openAiChatClient.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -79,9 +80,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -98,9 +99,9 @@ void beanOutputParser() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - ActorsFilms actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent()); } record ActorsFilmsRecord(String actor, List movies) { @@ -118,9 +119,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - 
Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); System.out.println(actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); @@ -143,9 +144,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java similarity index 82% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java index cd091b92261..4d34cfbd3bc 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java @@ -22,15 +22,11 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import 
org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.*; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -52,8 +48,8 @@ * @author Christian Tzolov * @since 0.7.0 */ -@RestClientTest(OpenAiChatClientWithGenerationMetadataTests.Config.class) -public class OpenAiChatClientWithGenerationMetadataTests { +@RestClientTest(OpenAiChatClientWithChatResponseMetadataTests.Config.class) +public class OpenAiChatClientWithChatResponseMetadataTests { private static String TEST_API_KEY = "sk-1234567890"; @@ -75,22 +71,22 @@ void aiResponseContainsAiMetadata() { Prompt prompt = new Prompt("Reach for the sky."); - ChatResponse response = this.openAiChatClient.generate(prompt); + ChatResponse response = this.openAiChatClient.call(prompt); assertThat(response).isNotNull(); - GenerationMetadata generationMetadata = response.getGenerationMetadata(); + ChatResponseMetadata chatResponseMetadata = response.getMetadata(); - assertThat(generationMetadata).isNotNull(); + assertThat(chatResponseMetadata).isNotNull(); - Usage usage = generationMetadata.getUsage(); + Usage usage = chatResponseMetadata.getUsage(); assertThat(usage).isNotNull(); assertThat(usage.getPromptTokens()).isEqualTo(9L); assertThat(usage.getGenerationTokens()).isEqualTo(12L); assertThat(usage.getTotalTokens()).isEqualTo(21L); - RateLimit rateLimit = generationMetadata.getRateLimit(); + RateLimit rateLimit = chatResponseMetadata.getRateLimit(); Duration expectedRequestsReset = Duration.ofDays(2L) 
.plus(Duration.ofHours(16L)) @@ -109,16 +105,16 @@ void aiResponseContainsAiMetadata() { assertThat(rateLimit.getTokensRemaining()).isEqualTo(112_358L); assertThat(rateLimit.getTokensReset()).isEqualTo(expectedTokensReset); - PromptMetadata promptMetadata = response.getPromptMetadata(); + PromptMetadata promptMetadata = response.getMetadata().getPromptMetadata(); assertThat(promptMetadata).isNotNull(); assertThat(promptMetadata).isEmpty(); - response.getGenerations().forEach(generation -> { - ChoiceMetadata choiceMetadata = generation.getChoiceMetadata(); - assertThat(choiceMetadata).isNotNull(); - assertThat(choiceMetadata.getFinishReason()).isEqualTo("STOP"); - assertThat(choiceMetadata.getContentFilterMetadata()).isNull(); + response.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP"); + assertThat(chatGenerationMetadata.getContentFilterMetadata()).isNull(); }); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index b2515a6a5c5..4c5d6f9af0c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -38,7 +38,7 @@ void simpleEmbedding() { EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getData()).hasSize(1); assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002-v2"); + assertThat(embeddingResponse.getMetadata()).containsEntry("generative", "text-embedding-ada-002-v2"); 
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2); assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java new file mode 100644 index 00000000000..bcc29dc3d66 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.image; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.image.*; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; +import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = OpenAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +public class OpenAiImageClientIT extends AbstractIT { + + @Test + void imageAsUrlTest() { + var options = ImageOptionsBuilder.builder().withHeight(256).withWidth(256).build(); + + ImagePrompt imagePrompt = new ImagePrompt("Create an image of a mini golden doodle dog.", options); + + ImageResponse imageResponse = openaiImageClient.call(imagePrompt); + + assertThat(imageResponse.getResults()).hasSize(1); + + ImageResponseMetadata imageResponseMetadata = imageResponse.getMetadata(); + assertThat(imageResponseMetadata.created()).isPositive(); + + var generation = imageResponse.getResult(); + Image image = generation.getOutput(); + assertThat(image.getUrl()).isNotEmpty(); + assertThat(image.getB64Json()).isNull(); + + var imageGenerationMetadata = generation.getMetadata(); + Assertions.assertThat(imageGenerationMetadata).isInstanceOf(OpenAiImageGenerationMetadata.class); + + OpenAiImageGenerationMetadata openAiImageGenerationMetadata = (OpenAiImageGenerationMetadata) imageGenerationMetadata; + + assertThat(openAiImageGenerationMetadata).isNotNull(); + assertThat(openAiImageGenerationMetadata.getRevisedPrompt()).isNull(); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java new file mode 100644 index 00000000000..5f660a1f8b0 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.image; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.MockRestServiceServer; +import 
org.springframework.web.client.RestClient; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * @author John Blum + * @author Christian Tzolov + * @since 0.7.0 + */ +@RestClientTest(OpenAiImageClientWithImageResponseMetadataTests.Config.class) +public class OpenAiImageClientWithImageResponseMetadataTests { + + private static final String TEST_API_KEY = "sk-1234567890"; + + @Autowired + private OpenAiImageClient openAiImageClient; + + @Autowired + private MockRestServiceServer server; + + @AfterEach + void resetMockServer() { + server.reset(); + } + + @Test + void aiResponseContainsImageResponseMetadata() { + + prepareMock(); + + ImagePrompt prompt = new ImagePrompt("Create an image of a mini golden doodle dog."); + + ImageResponse response = this.openAiImageClient.call(prompt); + + assertThat(response).isNotNull(); + List<ImageGeneration> imageGenerations = response.getResults(); + assertThat(imageGenerations).isNotNull(); + assertThat(imageGenerations).hasSize(2); + + ImageResponseMetadata imageResponseMetadata = response.getMetadata(); + + assertThat(imageResponseMetadata).isNotNull(); + + Long created = imageResponseMetadata.created(); + + assertThat(created).isNotNull(); + assertThat(created).isEqualTo(1589478378); + + ImageResponseMetadata responseMetadata = response.getMetadata(); + + assertThat(responseMetadata).isNotNull(); + + } + + private void prepareMock() { + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER.getName(), "999"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER.getName(), "2d16h15m29s"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER.getName(), "725000"); + 
httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); + + server.expect(requestTo("v1/images/generations")) + .andExpect(method(HttpMethod.POST)) + .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) + .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); + + } + + private String getJson() { + return """ + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + }, + { + "url": "https://upload.wikimedia.org/wikipedia/commons/8/85/Goldendoodle_puppy_Marty.jpg" + } + ] + } + """; + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiImageApi imageGenerationApi(RestClient.Builder builder) { + return new OpenAiImageApi("", TEST_API_KEY, builder); + } + + @Bean + public OpenAiImageClient openAiImageClient(OpenAiImageApi openAiImageApi) { + return new OpenAiImageClient(openAiImageApi); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index a41ca889c16..2e878f802b7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -9,10 +9,11 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.prompt.Prompt; +import 
org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.image.ImageClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -27,6 +28,9 @@ public abstract class AbstractIT { @Autowired protected ChatClient openAiChatClient; + @Autowired + protected ImageClient openaiImageClient; + @Autowired protected StreamingChatClient openStreamingChatClient; @@ -44,7 +48,7 @@ public abstract class AbstractIT { protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getContent(); + String answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -58,12 +62,12 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response, } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatClient.generate(prompt).getGeneration().getContent(); + String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatClient.generate(prompt).getGeneration().getContent(); + String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { diff --git 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index a4e79b3ec40..120b32a7be2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -65,7 +65,7 @@ public class MetadataTransformerIT { Document document2 = new Document( "The Spring Framework is divided into modules. Applications can choose which modules" - + " they need. At the heart are the modules of the core container, including a configuration model and a " + + " they need. At the heart are the modules of the core container, including a configuration model and a " + "dependency injection mechanism. Beyond that, the Spring Framework provides foundational support " + " for different application architectures, including messaging, transactional data and persistence, " + "and web. 
It also includes the Servlet-based Spring MVC web framework and, in parallel, the Spring " diff --git a/models/spring-ai-stabilityai/pom.xml b/models/spring-ai-stabilityai/pom.xml new file mode 100644 index 00000000000..d189362de22 --- /dev/null +++ b/models/spring-ai-stabilityai/pom.xml @@ -0,0 +1,60 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 0.8.0-SNAPSHOT + ../../pom.xml + + spring-ai-stability-ai + jar + Spring AI Stability AI + Stability AI support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework + spring-web + ${spring-framework.version} + + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java new file mode 100644 index 00000000000..edd67c532a1 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java @@ -0,0 +1,158 @@ +package org.springframework.ai.stabilityai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.image.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptionsBuilder; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptionsImpl; +import org.springframework.util.Assert; + +import 
java.util.List; +import java.util.stream.Collectors; + +public class StabilityAiImageClient implements ImageClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private StabilityAiImageOptions options; + + private final StabilityAiApi stabilityAiApi; + + public StabilityAiImageClient(StabilityAiApi stabilityAiApi) { + this(stabilityAiApi, StabilityAiImageOptionsBuilder.builder().build()); + } + + public StabilityAiImageClient(StabilityAiApi stabilityAiApi, StabilityAiImageOptions options) { + Assert.notNull(stabilityAiApi, "StabilityAiApi must not be null"); + Assert.notNull(options, "StabilityAiImageOptions must not be null"); + this.stabilityAiApi = stabilityAiApi; + this.options = options; + } + + public StabilityAiImageOptions getOptions() { + return options; + } + + /** + * Calls the StabilityAiImageClient with the given StabilityAiImagePrompt and returns + * the ImageResponse. This overloaded call method lets you pass the full set of Prompt + * instructions that StabilityAI supports. + * @param imagePrompt the StabilityAiImagePrompt containing the prompt and image model + * options + * @return the ImageResponse generated by the StabilityAiImageClient + */ + public ImageResponse call(ImagePrompt imagePrompt) { + + ImageOptions runtimeOptions = imagePrompt.getOptions(); + + // Merge the runtime options passed via the prompt with the StabilityAiImageClient + // options configured via Autoconfiguration. 
+ // Runtime options overwrite StabilityAiImageClient options + StabilityAiImageOptions optionsToUse = ModelOptionsUtils.merge(runtimeOptions, this.options, + StabilityAiImageOptionsImpl.class); + + // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data + // types to the data types used in StabilityAiApi + StabilityAiApi.GenerateImageRequest generateImageRequest = getGenerateImageRequest(imagePrompt, optionsToUse); + + // Make the request + StabilityAiApi.GenerateImageResponse generateImageResponse = this.stabilityAiApi + .generateImage(generateImageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(generateImageResponse); + + } + + private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, + StabilityAiImageOptions optionsToUse) { + StabilityAiApi.GenerateImageRequest.Builder builder = new StabilityAiApi.GenerateImageRequest.Builder(); + StabilityAiApi.GenerateImageRequest generateImageRequest = builder + .withTextPrompts(stabilityAiImagePrompt.getInstructions() + .stream() + .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), + message.getWeight())) + .collect(Collectors.toList())) + .withHeight(optionsToUse.getHeight()) + .withWidth(optionsToUse.getWidth()) + .withCfgScale(optionsToUse.getCfgScale()) + .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) + .withSampler(optionsToUse.getSampler()) + .withSamples(optionsToUse.getSamples()) + .withSeed(optionsToUse.getSeed()) + .withSteps(optionsToUse.getSteps()) + .withStylePreset(optionsToUse.getStylePreset()) + .build(); + return generateImageRequest; + } + + private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) { + List imageGenerationList = generateImageResponse.artifacts().stream().map(entry -> { + return new ImageGeneration(new Image(null, entry.base64()), + new 
StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed())); + }).toList(); + + return new ImageResponse(imageGenerationList, ImageResponseMetadata.NULL); + } + + private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) { + StabilityAiImageOptionsBuilder builder = StabilityAiImageOptionsBuilder.builder(); + if (runtimeOptions == null) { + return builder.build(); + } + if (runtimeOptions.getN() != null) { + builder.withN(runtimeOptions.getN()); + } + if (runtimeOptions.getModel() != null) { + builder.withModel(runtimeOptions.getModel()); + } + if (runtimeOptions.getResponseFormat() != null) { + builder.withResponseFormat(runtimeOptions.getResponseFormat()); + } + if (runtimeOptions.getWidth() != null) { + builder.withWidth(runtimeOptions.getWidth()); + } + if (runtimeOptions.getHeight() != null) { + builder.withHeight(runtimeOptions.getHeight()); + } + if (runtimeOptions instanceof StabilityAiImageOptions) { + StabilityAiImageOptions stabilityAiImageOptions = (StabilityAiImageOptions) runtimeOptions; + if (stabilityAiImageOptions.getCfgScale() != null) { + builder.withCfgScale(stabilityAiImageOptions.getCfgScale()); + } + if (stabilityAiImageOptions.getClipGuidancePreset() != null) { + builder.withClipGuidancePreset(stabilityAiImageOptions.getClipGuidancePreset()); + } + if (stabilityAiImageOptions.getSampler() != null) { + builder.withSampler(stabilityAiImageOptions.getSampler()); + } + if (stabilityAiImageOptions.getSeed() != null) { + builder.withSeed(stabilityAiImageOptions.getSeed()); + } + if (stabilityAiImageOptions.getSteps() != null) { + builder.withSteps(stabilityAiImageOptions.getSteps()); + } + if (stabilityAiImageOptions.getStylePreset() != null) { + builder.withStylePreset(stabilityAiImageOptions.getStylePreset()); + } + } + return builder.build(); + } + + private ImagePrompt createUpdatedPrompt(ImagePrompt prompt) { + ImageOptions runtimeImageModelOptions = prompt.getOptions(); + ImageOptionsBuilder imageOptionsBuilder = 
ImageOptionsBuilder.builder(); + + if (runtimeImageModelOptions != null) { + if (runtimeImageModelOptions.getModel() != null) { + imageOptionsBuilder.withModel(runtimeImageModelOptions.getModel()); + } + } + ImageOptions updatedImageModelOptions = imageOptionsBuilder.build(); + return new ImagePrompt(prompt.getInstructions(), updatedImageModelOptions); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java new file mode 100644 index 00000000000..8eeb4be9ae5 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java @@ -0,0 +1,45 @@ +package org.springframework.ai.stabilityai; + +import org.springframework.ai.image.ImageGenerationMetadata; + +import java.util.Objects; + +public class StabilityAiImageGenerationMetadata implements ImageGenerationMetadata { + + private String finishReason; + + private Long seed; + + public StabilityAiImageGenerationMetadata(String finishReason, Long seed) { + this.finishReason = finishReason; + this.seed = seed; + } + + public String getFinishReason() { + return finishReason; + } + + public Long getSeed() { + return seed; + } + + @Override + public String toString() { + return "StabilityAiImageGenerationMetadata{" + "finishReason='" + finishReason + '\'' + ", seed=" + seed + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof StabilityAiImageGenerationMetadata that)) + return false; + return Objects.equals(finishReason, that.finishReason) && Objects.equals(seed, that.seed); + } + + @Override + public int hashCode() { + return Objects.hash(finishReason, seed); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java 
b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java new file mode 100644 index 00000000000..9e15cec7f60 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -0,0 +1,212 @@ +package org.springframework.ai.stabilityai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import java.io.IOException; +import java.util.List; +import java.util.function.Consumer; + +public class StabilityAiApi { + + public static final String DEFAULT_IMAGE_MODEL = "stable-diffusion-v1-6"; + + public static final String DEFAULT_BASE_URL = "https://api.stability.ai/v1"; + + private final RestClient restClient; + + private final String apiKey; + + private final String model; + + /** + * Create a new StabilityAI API. + * @param apiKey StabilityAI apiKey. + */ + public StabilityAiApi(String apiKey) { + this(apiKey, DEFAULT_IMAGE_MODEL, DEFAULT_BASE_URL, RestClient.builder()); + } + + public StabilityAiApi(String apiKey, String model) { + this(apiKey, model, DEFAULT_BASE_URL, RestClient.builder()); + } + + public StabilityAiApi(String apiKey, String model, String baseUrl) { + this(apiKey, model, baseUrl, RestClient.builder()); + } + + /** + * Create a new StabilityAI API. + * @param apiKey StabilityAI apiKey. + * @param model StabilityAI model. + * @param baseUrl api base URL. + * @param restClientBuilder RestClient builder. 
+ */ + public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Builder restClientBuilder) { + + this.model = model; + this.apiKey = apiKey; + + Consumer jsonContentHeaders = headers -> { + headers.setBearerAuth(apiKey); + headers.setAccept(List.of(MediaType.APPLICATION_JSON)); // base64 in JSON + + // metadata or return + // image in bytes. + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() { + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(), + new ObjectMapper().readValue(response.getBody(), ResponseError.class))); + } + } + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResponseError(@JsonProperty("id") String id, @JsonProperty("name") String name, + @JsonProperty("message") String message + + ) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts, + @JsonProperty("height") Integer height, @JsonProperty("width") Integer width, + @JsonProperty("cfg_scale") Float cfgScale, @JsonProperty("clip_guidance_preset") String clipGuidancePreset, + @JsonProperty("sampler") String sampler, @JsonProperty("samples") Integer samples, + @JsonProperty("seed") Long seed, @JsonProperty("steps") Integer steps, + @JsonProperty("style_present") String stylePreset) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record TextPrompts(@JsonProperty("text") String text, @JsonProperty("weight") Float weight) { + + } + + 
public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + List textPrompts; + + Integer height; + + Integer width; + + Float cfgScale; + + String clipGuidancePreset; + + String sampler; + + Integer samples; + + Long seed; + + Integer steps; + + String stylePreset; + + public Builder() { + + } + + public Builder withTextPrompts(List textPrompts) { + this.textPrompts = textPrompts; + return this; + } + + public Builder withHeight(Integer height) { + this.height = height; + return this; + } + + public Builder withWidth(Integer width) { + this.width = width; + return this; + } + + public Builder withCfgScale(Float cfgScale) { + this.cfgScale = cfgScale; + return this; + } + + public Builder withClipGuidancePreset(String clipGuidancePreset) { + this.clipGuidancePreset = clipGuidancePreset; + return this; + } + + public Builder withSampler(String sampler) { + this.sampler = sampler; + return this; + } + + public Builder withSamples(Integer samples) { + this.samples = samples; + return this; + } + + public Builder withSeed(Long seed) { + this.seed = seed; + return this; + } + + public Builder withSteps(Integer steps) { + this.steps = steps; + return this; + } + + public Builder withStylePreset(String stylePreset) { + this.stylePreset = stylePreset; + return this; + } + + public GenerateImageRequest build() { + return new GenerateImageRequest(textPrompts, height, width, cfgScale, clipGuidancePreset, sampler, + samples, seed, steps, stylePreset); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GenerateImageResponse(@JsonProperty("result") String result, + @JsonProperty("artifacts") List artifacts) { + public record Artifacts(@JsonProperty("seed") long seed, @JsonProperty("base64") String base64, + @JsonProperty("finishReason") String finishReason) { + } + } + + public GenerateImageResponse generateImage(GenerateImageRequest request) { + Assert.notNull(request, "The request body can not be null."); + 
return this.restClient.post() + .uri("/generation/{model}/text-to-image", this.model) + .body(request) + .retrieve() + .body(GenerateImageResponse.class); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java new file mode 100644 index 00000000000..358cfd8bc70 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -0,0 +1,23 @@ +package org.springframework.ai.stabilityai.api; + +import org.springframework.ai.image.ImageOptions; + +public interface StabilityAiImageOptions extends ImageOptions { + + Float getCfgScale(); + + String getClipGuidancePreset(); + + String getSampler(); + + Integer getSamples(); + + Long getSeed(); + + Integer getSteps(); + + String getStylePreset(); + + // extras json object... + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java new file mode 100644 index 00000000000..17f5bb9beb6 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java @@ -0,0 +1,79 @@ +package org.springframework.ai.stabilityai.api; + +public class StabilityAiImageOptionsBuilder { + + private StabilityAiImageOptionsImpl options; + + private StabilityAiImageOptionsBuilder() { + options = new StabilityAiImageOptionsImpl(); + } + + public static StabilityAiImageOptionsBuilder builder() { + return new StabilityAiImageOptionsBuilder(); + } + + public StabilityAiImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public StabilityAiImageOptionsBuilder withModel(String model) { + options.setModel(model); + 
return this; + } + + public StabilityAiImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public StabilityAiImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public StabilityAiImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public StabilityAiImageOptionsBuilder withCfgScale(Float cfgScale) { + options.setCfgScale(cfgScale); + return this; + } + + public StabilityAiImageOptionsBuilder withClipGuidancePreset(String clipGuidancePreset) { + options.setClipGuidancePreset(clipGuidancePreset); + return this; + } + + public StabilityAiImageOptionsBuilder withSampler(String sampler) { + options.setSampler(sampler); + return this; + } + + public StabilityAiImageOptionsBuilder withSeed(Long seed) { + options.setSeed(seed); + return this; + } + + public StabilityAiImageOptionsBuilder withSteps(Integer steps) { + options.setSteps(steps); + return this; + } + + public StabilityAiImageOptionsBuilder withSamples(Integer samples) { + options.setSamples(samples); + return this; + } + + public StabilityAiImageOptionsBuilder withStylePreset(String stylePreset) { + options.setStylePreset(stylePreset); + return this; + } + + public StabilityAiImageOptions build() { + return options; + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java new file mode 100644 index 00000000000..0f7213ea860 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java @@ -0,0 +1,140 @@ +package org.springframework.ai.stabilityai.api; + +public class StabilityAiImageOptionsImpl implements StabilityAiImageOptions { + + private Integer n; + + private String model; 
+ + private Integer width; + + private Integer height; + + private String responseFormat; + + private Float cfgScale; + + private String clipGuidancePreset; + + private String sampler; + + private Integer samples; + + private Long seed; + + private Integer steps; + + private String stylePreset; + + public StabilityAiImageOptionsImpl() { + } + + @Override + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getWidth() { + return this.width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return this.height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + @Override + public String getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Float getCfgScale() { + return this.cfgScale; + } + + public void setCfgScale(Float cfgScale) { + this.cfgScale = cfgScale; + } + + @Override + public String getClipGuidancePreset() { + return this.clipGuidancePreset; + } + + public void setClipGuidancePreset(String clipGuidancePreset) { + this.clipGuidancePreset = clipGuidancePreset; + } + + @Override + public String getSampler() { + return this.sampler; + } + + public void setSampler(String sampler) { + this.sampler = sampler; + } + + @Override + public Integer getSamples() { + return this.samples; + } + + public void setSamples(Integer samples) { + this.samples = samples; + } + + @Override + public Long getSeed() { + return this.seed; + } + + public void setSeed(Long seed) { + this.seed = seed; + } + + @Override + public Integer getSteps() { + return this.steps; + } + + public void setSteps(Integer steps) { + this.steps = steps; + } + + @Override + 
public String getStylePreset() { + return this.stylePreset; + } + + public void setStylePreset(String stylePreset) { + this.stylePreset = stylePreset; + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java new file mode 100644 index 00000000000..d70f0f071c0 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java @@ -0,0 +1,64 @@ +package org.springframework.ai.stabilityai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.stabilityai.api.StabilityAiApi; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Base64; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") +public class StabilityAiApiIT { + + StabilityAiApi stabilityAiApi = new StabilityAiApi(System.getenv("STABILITYAI_API_KEY")); + + @Test + void generateImage() throws IOException { + + List textPrompts = List + .of(new StabilityAiApi.GenerateImageRequest.TextPrompts( + "A light cream colored mini golden doodle holding a sign that says 'Heading to BARCADE !'", 0.5f)); + var builder = StabilityAiApi.GenerateImageRequest.builder() + .withTextPrompts(textPrompts) + .withHeight(1024) + .withWidth(1024) + .withCfgScale(7f) + .withSamples(1) + .withSeed(123L) + .withSteps(30) + .withStylePreset("photographic"); + StabilityAiApi.GenerateImageRequest request = builder.build(); + StabilityAiApi.GenerateImageResponse response = stabilityAiApi.generateImage(request); + + assertThat(response).isNotNull(); + List artifacts = response.artifacts(); + writeToFile(artifacts); + assertThat(artifacts).hasSize(1); + var firstArtifact = 
artifacts.get(0); + assertThat(firstArtifact.base64()).isNotEmpty(); + assertThat(firstArtifact.seed()).isPositive(); + assertThat(firstArtifact.finishReason()).isEqualTo("SUCCESS"); + + } + + private static void writeToFile(List artifacts) throws IOException { + int counter = 0; + String systemTempDir = System.getProperty("java.io.tmpdir"); + for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { + counter++; + byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); + String fileName = String.format("dog%d.png", counter); + String filePath = systemTempDir + File.separator + fileName; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java new file mode 100644 index 00000000000..ea6f74d5734 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java @@ -0,0 +1,50 @@ +package org.springframework.ai.stabilityai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.image.*; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Base64; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = StabilityAiImageTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") +public class StabilityAiImageClientIT { + + @Autowired + protected 
ImageClient stabilityAiImageClient; + + @Test + void imageAsBase64Test() throws IOException { + ImagePrompt imagePrompt = new ImagePrompt( + "A light cream colored mini golden doodle holding a sign that says 'I want to go with you on vacation!'"); + + ImageResponse imageResponse = this.stabilityAiImageClient.call(imagePrompt); + + ImageGeneration imageGeneration = imageResponse.getResult(); + Image image = imageGeneration.getOutput(); + + assertThat(image.getB64Json()).isNotEmpty(); + + writeFile(image); + } + + private static void writeFile(Image image) throws IOException { + byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); + String systemTempDir = System.getProperty("java.io.tmpdir"); + String filePath = systemTempDir + File.separator + "dog.png"; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java new file mode 100644 index 00000000000..a4ff9ed0390 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java @@ -0,0 +1,30 @@ +package org.springframework.ai.stabilityai; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +@SpringBootConfiguration +public class StabilityAiImageTestConfiguration { + + @Bean + public StabilityAiApi stabilityAiApi() { + return new StabilityAiApi(getApiKey()); + } + + @Bean + StabilityAiImageClient stabilityAiImageClient(StabilityAiApi stabilityAiApi) { + return new StabilityAiImageClient(stabilityAiApi); + } + + private String getApiKey() { + String 
apiKey = System.getenv("STABILITYAI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name STABILITYAI_API_KEY"); + } + return apiKey; + } + +} diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java index b8d7695fe8e..094dd2438e0 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java @@ -56,7 +56,7 @@ public class ResourceCacheService { private List excludedUriSchemas = new ArrayList<>(List.of("file", "classpath")); public ResourceCacheService() { - this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath()); + this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-generative").getAbsolutePath()); } public ResourceCacheService(String rootCacheDirectory) { diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java index 66b1dec0f5b..087510f279f 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java @@ -41,10 +41,10 @@ public class TransformersEmbeddingClient extends AbstractEmbeddingClient impleme private static final Log logger = LogFactory.getLog(TransformersEmbeddingClient.class); - // ONNX tokenizer for the all-MiniLM-L6-v2 model + // ONNX tokenizer for the all-MiniLM-L6-v2 generative public final 
static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; - // ONNX model for all-MiniLM-L6-v2 pre-trained transformer: + // ONNX generative for all-MiniLM-L6-v2 pre-trained transformer: // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 public final static String DEFAULT_ONNX_MODEL_URI = "https://github.com/spring-projects/spring-ai/raw/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx"; @@ -70,7 +70,7 @@ public class TransformersEmbeddingClient extends AbstractEmbeddingClient impleme private OrtEnvironment environment; /** - * Runtime session that wraps the ONNX model and enables inference calls. + * Runtime session that wraps the ONNX generative and enables inference calls. */ private OrtSession session; @@ -181,7 +181,7 @@ public void afterPropertiesSet() throws Exception { logger.info("Model output names: " + onnxModelOutputs.stream().collect(Collectors.joining(", "))); Assert.isTrue(onnxModelOutputs.contains(this.modelOutputName), - "The model output names doesn't contain expected: " + this.modelOutputName); + "The generative output names doesn't contain expected: " + this.modelOutputName); } private Resource getCachedResource(Resource resource) { diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java index 26f5076e1c2..887f3fe9d46 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java @@ -58,7 +58,7 @@ public static NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask public static void main(String[] args) throws Exception { 
String TOKENIZER_URI = "classpath:/onnx/tokenizer.json"; - String MODEL_URI = "classpath:/onnx/model.onnx"; + String MODEL_URI = "classpath:/onnx/generative.onnx"; var tokenizerResource = new DefaultResourceLoader().getResource(TOKENIZER_URI); var modelResource = new DefaultResourceLoader().getResource(MODEL_URI); diff --git a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java index c7ec053c31f..1c922978180 100644 --- a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java +++ b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java @@ -22,8 +22,8 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.vertex.api.VertexAiApi; import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest; import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse; @@ -71,15 +71,15 @@ public VertexAiChatClient withCandidateCount(Integer maxTokens) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { - String vertexContext = prompt.getMessages() + String vertexContext = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.SYSTEM) .map(m -> m.getContent()) .collect(Collectors.joining("\n")); - List vertexMessages = prompt.getMessages() + List vertexMessages = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) .map(m -> new 
VertexAiApi.Message(m.getMessageType().getValue(), m.getContent())) diff --git a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java index 9e9609dd20f..811be1fdc91 100644 --- a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java +++ b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java @@ -112,7 +112,7 @@ public class VertexAiApi { private final String embeddingModel; /** - * Create an new chat completion api. + * Create a new chat completion api. * @param apiKey vertex apiKey. */ public VertexAiApi(String apiKey) { @@ -120,7 +120,7 @@ public VertexAiApi(String apiKey) { } /** - * Create an new chat completion api. + * Create a new chat completion api. * @param baseUrl api base URL. * @param apiKey vertex apiKey. * @param model vertex model. diff --git a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java index 78645d27a2a..ba6182f652b 100644 --- a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java +++ b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java @@ -78,7 +78,7 @@ public void generateMessage() throws JsonProcessingException { List.of(new VertexAiApi.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason"))); server - .expect(requestToUriTemplate("/models/{model}:generateMessage?key={apiKey}", + .expect(requestToUriTemplate("/models/{generative}:generateMessage?key={apiKey}", VertexAiApi.DEFAULT_GENERATE_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(request))) @@ -99,8 +99,8 @@ public void embedText() throws JsonProcessingException { 
Embedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3)); server - .expect(requestToUriTemplate("/models/{model}:embedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL, - TEST_API_KEY)) + .expect(requestToUriTemplate("/models/{generative}:embedText?key={apiKey}", + VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text)))) .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), @@ -122,7 +122,7 @@ public void batchEmbedText() throws JsonProcessingException { new Embedding(List.of(0.4, 0.5, 0.6))); server - .expect(requestToUriTemplate("/models/{model}:batchEmbedText?key={apiKey}", + .expect(requestToUriTemplate("/models/{generative}:batchEmbedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts)))) diff --git a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java index 69bd76d686e..90bc23e8e8a 100644 --- a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java +++ b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java @@ -12,11 +12,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import 
org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.vertex.VertexAiChatClient; import org.springframework.ai.vertex.api.VertexAiApi; import org.springframework.beans.factory.annotation.Autowired; @@ -48,8 +48,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Bartholomew"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Bartholomew"); } // @Test @@ -65,9 +65,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -84,9 +84,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = 
outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -106,9 +106,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } diff --git a/pom.xml b/pom.xml index 338ef2c7522..09e012bc6d8 100644 --- a/pom.xml +++ b/pom.xml @@ -14,14 +14,15 @@ spring-ai-core + models/spring-ai-transformers + models/spring-ai-postgresml models/spring-ai-bedrock models/spring-ai-azure-openai models/spring-ai-huggingface models/spring-ai-ollama models/spring-ai-openai models/spring-ai-vertex-ai - models/spring-ai-transformers - models/spring-ai-postgresml + models/spring-ai-stabilityai spring-ai-test spring-ai-spring-boot-autoconfigure spring-ai-spring-boot-starters/spring-ai-starter-openai diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java index 373339467d5..89162ab2d4b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java @@ -16,17 +16,18 @@ package org.springframework.ai.chat; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.UserMessage; 
+import org.springframework.ai.model.ModelClient; @FunctionalInterface -public interface ChatClient { +public interface ChatClient extends ModelClient { - default String generate(String message) { + default String call(String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return generate(prompt).getGeneration().getContent(); + return call(prompt).getResult().getOutput().getContent(); } - ChatResponse generate(Prompt prompt); + ChatResponse call(Prompt prompt); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java new file mode 100644 index 00000000000..f6cd9053159 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat; + +import org.springframework.ai.model.ModelOptions; + +/** + * portable options + */ +public interface ChatOptions extends ModelOptions { + + // determine portable options + + Float getTemperature(); + + void setTemperature(Float temperature); + + Float getTopP(); + + void setTopP(Float topP); + + Integer getTopK(); + + void setTopK(Integer topK); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java new file mode 100644 index 00000000000..60c72987215 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java @@ -0,0 +1,89 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.springframework.ai.chat; + +public class ChatOptionsBuilder { + + private class ChatOptionsImpl implements ChatOptions { + + private Float temperature; + + private Float topP; + + private Integer topK; + + @Override + public Float getTemperature() { + return temperature; + } + + @Override + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + @Override + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + @Override + public void setTopK(Integer topK) { + this.topK = topK; + } + + } + + private final ChatOptionsImpl options = new ChatOptionsImpl(); + + private ChatOptionsBuilder() { + } + + public static ChatOptionsBuilder builder() { + return new ChatOptionsBuilder(); + } + + public ChatOptionsBuilder withTemperature(Float temperature) { + options.setTemperature(temperature); + return this; + } + + public ChatOptionsBuilder withTopP(Float topP) { + options.setTopP(topP); + return this; + } + + public ChatOptionsBuilder withTopK(Integer topK) { + options.setTopK(topK); + return this; + } + + public ChatOptions build() { + return options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java index f68f6bb317a..6ad2dab3e8c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java @@ -17,42 +17,39 @@ import java.util.List; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.lang.Nullable; +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; /** * The chat completion (e.g. 
generation) response returned by an AI provider. */ -public class ChatResponse { +public class ChatResponse implements ModelResponse { - private final GenerationMetadata metadata; + private final ChatResponseMetadata chatResponseMetadata; /** * List of generated messages returned by the AI provider. */ private final List generations; - private PromptMetadata promptMetadata; - /** * Construct a new {@link ChatResponse} instance without metadata. * @param generations the {@link List} of {@link Generation} returned by the AI * provider. */ public ChatResponse(List generations) { - this(generations, GenerationMetadata.NULL); + this(generations, ChatResponseMetadata.NULL); } /** * Construct a new {@link ChatResponse} instance. * @param generations the {@link List} of {@link Generation} returned by the AI * provider. - * @param metadata {@link GenerationMetadata} containing information about the use of - * the AI provider's API. + * @param chatResponseMetadata {@link ChatResponseMetadata} containing information + * about the use of the AI provider's API. */ - public ChatResponse(List generations, GenerationMetadata metadata) { - this.metadata = metadata; + public ChatResponse(List generations, ChatResponseMetadata chatResponseMetadata) { + this.chatResponseMetadata = chatResponseMetadata; this.generations = List.copyOf(generations); } @@ -63,45 +60,25 @@ public ChatResponse(List generations, GenerationMetadata metadata) { * multiple output {@link Generation generations}. * @return the {@link List} of {@link Generation generated outputs}. */ - public List getGenerations() { + + @Override + public List getResults() { return this.generations; } /** * @return Returns the first {@link Generation} in the generations list. */ - public Generation getGeneration() { + public Generation getResult() { return this.generations.get(0); } /** - * @return Returns {@link GenerationMetadata} containing information about the use of - * the AI provider's API. 
- */ - public GenerationMetadata getGenerationMetadata() { - return this.metadata; - } - - /** - * @return {@link PromptMetadata} containing information on prompt processing by the - * AI. - */ - public PromptMetadata getPromptMetadata() { - PromptMetadata promptMetadata = this.promptMetadata; - return promptMetadata != null ? promptMetadata : PromptMetadata.empty(); - } - - /** - * Builder method used to include {@link PromptMetadata} returned in the AI response - * when processing the prompt. - * @param promptMetadata {@link PromptMetadata} returned by the AI in the response - * when processing the prompt. - * @return this {@link ChatResponse}. - * @see #getPromptMetadata() + * @return Returns {@link ChatResponseMetadata} containing information about the use + * of the AI provider's API. */ - public ChatResponse withPromptMetadata(@Nullable PromptMetadata promptMetadata) { - this.promptMetadata = promptMetadata; - return this; + public ChatResponseMetadata getMetadata() { + return this.chatResponseMetadata; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java index 71fd3455a45..6234a0d562b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java @@ -16,46 +16,65 @@ package org.springframework.ai.chat; -import java.util.Collections; import java.util.Map; +import java.util.Objects; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.prompt.messages.AbstractMessage; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.model.ModelResult; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.lang.Nullable; /** * Represents a response returned by the AI. 
*/ -public class Generation extends AbstractMessage { +public class Generation implements ModelResult { - private ChoiceMetadata choiceMetadata; + private AssistantMessage assistantMessage; + + private ChatGenerationMetadata chatGenerationMetadata; public Generation(String text) { - this(text, Collections.emptyMap()); + this.assistantMessage = new AssistantMessage(text); } - public Generation(String content, Map properties) { - super(MessageType.ASSISTANT, content, properties); + public Generation(String text, Map properties) { + this.assistantMessage = new AssistantMessage(text, properties); } - public Generation(String content, Map properties, MessageType type) { - super(type, content, properties); + @Override + public AssistantMessage getOutput() { + return this.assistantMessage; } - public ChoiceMetadata getChoiceMetadata() { - ChoiceMetadata choiceMetadata = this.choiceMetadata; - return choiceMetadata != null ? choiceMetadata : ChoiceMetadata.NULL; + public ChatGenerationMetadata getMetadata() { + ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata; + return chatGenerationMetadata != null ? 
chatGenerationMetadata : ChatGenerationMetadata.NULL; } - public Generation withChoiceMetadata(@Nullable ChoiceMetadata choiceMetadata) { - this.choiceMetadata = choiceMetadata; + public Generation withGenerationMetadata(@Nullable ChatGenerationMetadata chatGenerationMetadata) { + this.chatGenerationMetadata = chatGenerationMetadata; return this; } + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Generation that)) + return false; + return Objects.equals(assistantMessage, that.assistantMessage) + && Objects.equals(chatGenerationMetadata, that.chatGenerationMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(assistantMessage, chatGenerationMetadata); + } + @Override public String toString() { - return "Generation{" + "text='" + content + '\'' + ", info=" + properties + '}'; + return "Generation{" + "assistantMessage=" + assistantMessage + ", chatGenerationMetadata=" + + chatGenerationMetadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java index d697636151a..beb67ae6a3c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java @@ -18,7 +18,7 @@ import reactor.core.publisher.Flux; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; @FunctionalInterface public interface StreamingChatClient { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java similarity index 97% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java rename to 
spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 9f4c123ecd0..24ecf3c1e06 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; @@ -31,7 +31,7 @@ public abstract class AbstractMessage implements Message { protected String content; /** - * Additional options for the message to influence the response, not a model map. + * Additional options for the message to influence the response, not a generative map. */ protected Map properties = new HashMap<>(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java similarity index 82% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index 2e064bd65f8..02e29146953 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; /** - * Lets the model know the content was generated as a response to the user. This role - * indicates messages that the model has previously generated in the conversation. 
By - * including assistant messages in the series, you provide context to the model about + * Lets the generative know the content was generated as a response to the user. This role + * indicates messages that the generative has previously generated in the conversation. By + * including assistant messages in the series, you provide context to the generative about * prior exchanges in the conversation. */ public class AssistantMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java index 6dc660ccb27..194aa54afaf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java index c00520dcdda..5fa90f50be0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java similarity index 94% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java index a405eb10b81..e1e6eda348c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java index 267727c7c00..6cb07a045ba 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; public enum MessageType { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java similarity index 88% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index 31179045122..ac70e593238 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -14,15 +14,16 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; /** * A message of the type 'system' passed as input. The system message gives high level * instructions for the conversation. This role typically provides high-level instructions - * for the conversation. For example, you might use a system message to instruct the model - * to behave like a certain character or to provide answers in a specific format. + * for the conversation. For example, you might use a system message to instruct the + * generative to behave like a certain character or to provide answers in a specific + * format. 
*/ public class SystemMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java similarity index 93% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 8c7fdcae6fa..6b91cc1d2e1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; /** * A message of the type 'user' passed as input Messages with the user role are from the * end-user or developer. They represent questions, prompts, or any input that you want - * the model to respond to. + * the generative to respond to. */ public class UserMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java index 9b0b3a80315..f485c7a294a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java index 2adad4b9de3..edec477abac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; /** * Abstract base class used as a foundation for implementing {@link Usage}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java similarity index 73% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java index 9c15fbdb226..d9f5fc56eb7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java @@ -14,8 +14,9 @@ * limitations under the License. 
*/ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; +import org.springframework.ai.model.ResultMetadata; import org.springframework.lang.Nullable; /** @@ -25,21 +26,21 @@ * @author John Blum * @since 0.7.0 */ -public interface ChoiceMetadata { +public interface ChatGenerationMetadata extends ResultMetadata { - ChoiceMetadata NULL = ChoiceMetadata.from(null, null); + ChatGenerationMetadata NULL = ChatGenerationMetadata.from(null, null); /** - * Factory method used to construct a new {@link ChoiceMetadata} from the given - * {@link String finish reason} and content filter metadata. + * Factory method used to construct a new {@link ChatGenerationMetadata} from the + * given {@link String finish reason} and content filter metadata. * @param finishReason {@link String} contain the reason for the choice completion. * @param contentFilterMetadata underlying AI provider metadata for filtering applied * to generation content. - * @return a new {@link ChoiceMetadata} from the given {@link String finish reason} - * and content filter metadata. + * @return a new {@link ChatGenerationMetadata} from the given {@link String finish + * reason} and content filter metadata. 
*/ - static ChoiceMetadata from(String finishReason, Object contentFilterMetadata) { - return new ChoiceMetadata() { + static ChatGenerationMetadata from(String finishReason, Object contentFilterMetadata) { + return new ChatGenerationMetadata() { @Override @SuppressWarnings("unchecked") diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java similarity index 79% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java index 3bd641eca24..12c9d0b5c8f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java @@ -14,7 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; + +import org.springframework.ai.model.ResponseMetadata; /** * Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI @@ -23,9 +25,9 @@ * @author John Blum * @since 0.7.0 */ -public interface GenerationMetadata { +public interface ChatResponseMetadata extends ResponseMetadata { - GenerationMetadata NULL = new GenerationMetadata() { + ChatResponseMetadata NULL = new ChatResponseMetadata() { }; /** @@ -46,4 +48,8 @@ default Usage getUsage() { return Usage.NULL; } + default PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java similarity index 98% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java rename to 
spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java index 302481b6e38..6ecb0924603 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.util.Arrays; import java.util.Optional; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java similarity index 98% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java index 40bb933d03b..66bdebd2b40 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java similarity index 97% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index 7179d5e127e..9cc1dbb711a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; /** * Abstract Data Type (ADT) encapsulating metadata on the usage of an AI provider's API diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java new file mode 100644 index 00000000000..98d92eb7127 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java @@ -0,0 +1,14 @@ +/** + * The org.sf.ai.chat package represents the bounded context for the Chat Model within the + * AI generative model domain. This package extends the core domain defined in + * org.sf.ai.generative, providing implementations specific to chat-based generative AI + * interactions. + * + * In line with Domain-Driven Design principles, this package includes implementations of + * entities and value objects specific to the chat context, such as ChatPrompt and + * ChatResponse, adhering to the ubiquitous language of chat interactions in AI models. + * + * This bounded context is designed to encapsulate all aspects of chat-based AI + * functionalities, maintaining a clear boundary from other contexts within the AI domain. 
+ */ +package org.springframework.ai.chat; \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java similarity index 90% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java index 715763f47fe..8a167f00590 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java index dbef2e4eab9..800c4586ecb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.ArrayList; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java similarity index 94% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java index 41504364ffd..4c7ce981ffc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; public class FunctionPromptTemplate extends PromptTemplate { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java similarity index 53% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index b0acf5fb4c4..29f671e7b57 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -14,19 +14,23 @@ * limitations under the License. 
*/ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.model.ModelRequest; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import java.util.Collections; import java.util.List; import java.util.Objects; -public class Prompt { +public class Prompt implements ModelRequest> { private final List messages; + private ModelOptions modelOptions; + public Prompt(String contents) { this(new UserMessage(contents)); } @@ -39,36 +43,53 @@ public Prompt(List messages) { this.messages = messages; } + public Prompt(String contents, ModelOptions modelOptions) { + this(new UserMessage(contents), modelOptions); + } + + public Prompt(Message message, ModelOptions modelOptions) { + this(Collections.singletonList(message), modelOptions); + } + + public Prompt(List messages, ModelOptions modelOptions) { + this.messages = messages; + this.modelOptions = modelOptions; + } + public String getContents() { StringBuilder sb = new StringBuilder(); - for (Message message : getMessages()) { + for (Message message : getInstructions()) { sb.append(message.getContent()); } return sb.toString(); } - public List getMessages() { + public ModelOptions getOptions() { + return modelOptions; + } + + @Override + public List getInstructions() { return this.messages; } @Override public String toString() { - return "Prompt{" + "messages=" + messages + '}'; + return "Prompt{" + "messages=" + messages + ", modelOptions=" + modelOptions + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; - if (o == null || getClass() != o.getClass()) + if (!(o instanceof Prompt prompt)) return false; - Prompt prompt = (Prompt) o; - return Objects.equals(messages, prompt.messages); + return 
Objects.equals(messages, prompt.messages) && Objects.equals(modelOptions, prompt.modelOptions); } @Override public int hashCode() { - return Objects.hash(messages); + return Objects.hash(messages, modelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java index 6207aefeadf..4ac2a1b92d9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java @@ -14,13 +14,13 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import org.antlr.runtime.Token; import org.antlr.runtime.TokenStream; import org.springframework.ai.parser.OutputParser; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; import org.stringtemplate.v4.ST; @@ -189,7 +189,7 @@ public Prompt create(Map model) { return new Prompt(render(model)); } - protected Set getInputVariables() { + public Set getInputVariables() { TokenStream tokens = this.st.impl.tokens; return IntStream.range(0, tokens.range()) .mapToObj(tokens::get) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java similarity index 94% rename from 
spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java index e1637384f88..3381eaddd50 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java similarity index 66% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java index b12384c65ef..361de4ade93 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java @@ -1,6 +1,6 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.List; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java similarity index 61% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java index bf48bf5460d..47e24620c10 100644 --- 
a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java @@ -1,6 +1,6 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java similarity index 75% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java index 8c041e81c13..86495a49e3e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java @@ -1,4 +1,4 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java similarity index 89% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java index 73d84902c98..539287d070b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java @@ -14,10 +14,10 @@ * limitations under the License. 
*/ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.core.io.Resource; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java index 9782db6b3b4..1001ee26600 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; public enum TemplateFormat { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index fcc7a570329..f02a2d0b5d0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -71,7 +71,7 @@ public class DefaultContentFormatter implements ContentFormatter { private final List excludedInferenceMetadataKeys; /** - * Metadata keys that are excluded from text for the embed model. + * Metadata keys that are excluded from text for the embed model.
*/ private final List excludedEmbedMetadataKeys; @@ -157,7 +157,8 @@ public Builder withTextTemplate(String textTemplate) { } /** - * Configures the excluded Inference metadata keys to filter out from the model. + * Configures the excluded Inference metadata keys to filter out from the + * model. * @param excludedInferenceMetadataKeys Excluded inference metadata keys to use. * @return this builder */ @@ -174,7 +175,7 @@ public Builder withExcludedInferenceMetadataKeys(String... keys) { } /** - * Configures the excluded Embed metadata keys to filter out from the model. + * Configures the excluded Embed metadata keys to filter out from the model. * @param excludedEmbedMetadataKeys Excluded Embed metadata keys to use. * @return this builder */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java index b0bd551c7ff..fdc519ad667 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java @@ -38,7 +38,8 @@ public interface EmbeddingClient { EmbeddingResponse embedForResponse(List texts); /** - * @return the number of dimensions of the embedded vectors. It is model specific. + * @return the number of dimensions of the embedded vectors. It is model + * specific.
*/ default int dimensions() { return embed("Test String").size(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java index 451cea484f1..91d621b0caf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java @@ -31,11 +31,11 @@ public class EmbeddingUtil { private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); /** - * Return the dimension of the requested embedding model name. If the model name is - * unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed and count - * the response dimensions. + * Return the dimension of the requested embedding model name. If the model + * name is unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed + * and count the response dimensions. * @param embeddingClient Fall-back client to determine, empirically the dimensions. - * @param modelName Embedding model name to retrieve the dimensions for. + * @param modelName Embedding model name to retrieve the dimensions for. * @param dummyContent Dummy content to use for the empirical dimension calculation. * @return Returns the embedding dimensions for the modelName. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java new file mode 100644 index 00000000000..fa1f0b8ff9d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import java.util.Objects; + +public class Image { + + /** + * The URL where the image can be accessed. + */ + private String url; + + /** + * Base64 encoded image string. + */ + private String b64Json; + + public Image(String url, String b64Json) { + this.url = url; + this.b64Json = b64Json; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getB64Json() { + return b64Json; + } + + public void setB64Json(String b64Json) { + this.b64Json = b64Json; + } + + @Override + public String toString() { + return "Image{" + "url='" + url + '\'' + ", b64Json='" + b64Json + '\'' + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Image image)) + return false; + return Objects.equals(url, image.url) && Objects.equals(b64Json, image.b64Json); + } + + @Override + public int hashCode() { + return Objects.hash(url, b64Json); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java new file mode 100644 index 00000000000..bf06964e1ff --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelClient; + +@FunctionalInterface +public interface ImageClient extends ModelClient { + + ImageResponse call(ImagePrompt request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java new file mode 100644 index 00000000000..94f739266a3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelResult; + +public class ImageGeneration implements ModelResult { + + private ImageGenerationMetadata imageGenerationMetadata; + + private Image image; + + public ImageGeneration(Image image) { + this.image = image; + } + + public ImageGeneration(Image image, ImageGenerationMetadata imageGenerationMetadata) { + this.image = image; + this.imageGenerationMetadata = imageGenerationMetadata; + } + + @Override + public Image getOutput() { + return image; + } + + @Override + public ImageGenerationMetadata getMetadata() { + return imageGenerationMetadata; + } + + @Override + public String toString() { + return "ImageGeneration{" + "imageGenerationMetadata=" + imageGenerationMetadata + ", image=" + image + '}'; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java new file mode 100644 index 00000000000..e140aa814be --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java @@ -0,0 +1,23 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ResultMetadata; + +public interface ImageGenerationMetadata extends ResultMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java new file mode 100644 index 00000000000..51d378b8c32 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.image; + +import java.util.Objects; + +public class ImageMessage { + + private String text; + + private Float weight; + + public ImageMessage(String text) { + this.text = text; + } + + public ImageMessage(String text, Float weight) { + this.text = text; + this.weight = weight; + } + + public String getText() { + return text; + } + + public Float getWeight() { + return weight; + } + + @Override + public String toString() { + return "ImageMessage{" + "text='" + text + '\'' + ", weight=" + weight + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImageMessage that)) + return false; + return Objects.equals(text, that.text) && Objects.equals(weight, that.weight); + } + + @Override + public int hashCode() { + return Objects.hash(text, weight); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java new file mode 100644 index 00000000000..dbfec79c9d6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelOptions; + +public interface ImageOptions extends ModelOptions { + + Integer getN(); + + String getModel(); + + Integer getWidth(); + + Integer getHeight(); + + String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java new file mode 100644 index 00000000000..49dc3497d3e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -0,0 +1,119 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.image; + +public class ImageOptionsBuilder { + + private class ImageModelOptionsImpl implements ImageOptions { + + private Integer n; + + private String model; + + private Integer width; + + private Integer height; + + private String responseFormat; + + @Override + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Integer getWidth() { + return width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + } + + private final ImageModelOptionsImpl options = new ImageModelOptionsImpl(); + + private ImageOptionsBuilder() { + + } + + public static ImageOptionsBuilder builder() { + return new ImageOptionsBuilder(); + } + + public ImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public ImageOptionsBuilder withModel(String model) { + options.setModel(model); + return this; + } + + public ImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public ImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public ImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public ImageOptions build() { + return options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java new file mode 100644 index 
00000000000..5ea58bea469 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -0,0 +1,81 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class ImagePrompt implements ModelRequest> { + + private final List messages; + + private ImageOptions imageModelOptions; + + public ImagePrompt(List messages) { + this.messages = messages; + } + + public ImagePrompt(List messages, ImageOptions imageModelOptions) { + this.messages = messages; + this.imageModelOptions = imageModelOptions; + } + + public ImagePrompt(ImageMessage imageMessage, ImageOptions imageOptions) { + this(Collections.singletonList(imageMessage), imageOptions); + } + + public ImagePrompt(String instructions, ImageOptions imageOptions) { + this(new ImageMessage(instructions), imageOptions); + } + + public ImagePrompt(String instructions) { + this(new ImageMessage(instructions), ImageOptionsBuilder.builder().build()); + } + + @Override + public List getInstructions() { + return messages; + } + + @Override + public ImageOptions getOptions() { + return imageModelOptions; + } + + @Override + public String toString() { + return "ImagePrompt{" + "messages=" + messages + ", imageModelOptions=" +
imageModelOptions + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImagePrompt that)) + return false; + return Objects.equals(messages, that.messages) && Objects.equals(imageModelOptions, that.imageModelOptions); + } + + @Override + public int hashCode() { + return Objects.hash(messages, imageModelOptions); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java new file mode 100644 index 00000000000..ad6cda7c9e7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.image; + +import java.util.List; +import java.util.Objects; + +import org.springframework.ai.model.ModelResponse; + +public class ImageResponse implements ModelResponse { + + private final ImageResponseMetadata imageResponseMetadata; + + private final List imageGenerations; + + public ImageResponse(List generations) { + this(generations, ImageResponseMetadata.NULL); + } + + public ImageResponse(List generations, ImageResponseMetadata imageResponseMetadata) { + this.imageResponseMetadata = imageResponseMetadata; + this.imageGenerations = List.copyOf(generations); + } + + @Override + public ImageGeneration getResult() { + return imageGenerations.get(0); + } + + @Override + public List getResults() { + return imageGenerations; + } + + @Override + public ImageResponseMetadata getMetadata() { + return imageResponseMetadata; + } + + @Override + public String toString() { + return "ImageResponse{" + "imageResponseMetadata=" + imageResponseMetadata + ", imageGenerations=" + + imageGenerations + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImageResponse that)) + return false; + return Objects.equals(imageResponseMetadata, that.imageResponseMetadata) + && Objects.equals(imageGenerations, that.imageGenerations); + } + + @Override + public int hashCode() { + return Objects.hash(imageResponseMetadata, imageGenerations); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java new file mode 100644 index 00000000000..7378fedca6e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.image; + +import org.springframework.ai.model.ResponseMetadata; + +public interface ImageResponseMetadata extends ResponseMetadata { + + ImageResponseMetadata NULL = new ImageResponseMetadata() { + }; + + default Long created() { + return System.currentTimeMillis(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java new file mode 100644 index 00000000000..38de2f149b7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * The ModelClient interface provides a generic API for invoking AI models. 
It is designed + * to handle the interaction with various types of AI models by abstracting the process of + * sending requests and receiving responses. The interface uses Java generics to + * accommodate different types of requests and responses, enhancing flexibility and + * adaptability across different AI model implementations. + * + * @param the generic type of the request to the AI model + * @param the generic type of the response from the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelClient, TRes extends ModelResponse> { + + /** + * Executes a method call to the AI model. + * @param request the request object to be sent to the AI model + * @return the response from the AI model + */ + TRes call(TReq request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java new file mode 100644 index 00000000000..9b6e908f4da --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * Interface representing the customizable options for AI model interactions. This + * interface allows for the specification of various settings and parameters that can + * influence the behavior and output of AI models. 
It is designed to provide flexibility + * and adaptability in different AI scenarios, ensuring that the AI models can be + * fine-tuned according to specific requirements. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelOptions { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java new file mode 100644 index 00000000000..85c47554ef7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +public abstract class ModelOptionsUtils { + + private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private ModelOptionsUtils() { + + } + + /** + * Merges the source object into the target object and returns an object represented + * by the given class. The source null values are ignored. + * @param <T> the type of the class to return. + * @param source the source object to merge. + * @param target the target object to merge into.
+ * @param clazz the class to return. + * @return the merged object represented by the given class. + */ + public static T merge(Object source, Object target, Class clazz) { + Map sourceMap = objectToMap(source); + Map targetMap = objectToMap(target); + + targetMap.putAll(sourceMap.entrySet() + .stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + + return mapToClass(targetMap, clazz); + } + + /** + * Converts the given object to a Map. + * @param source the object to convert to a Map. + * @return the converted Map. + */ + public static Map objectToMap(Object source) { + if (source == null) { + return new HashMap<>(); + } + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Converts the given Map to the given class. + * @param the type of the class to return. + * @param source the Map to convert to the given class. + * @param clazz the class to convert the Map to. + * @return the converted class. + */ + public static T mapToClass(Map source, Class clazz) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, clazz); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java new file mode 100644 index 00000000000..0aac6da82c7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * Interface representing a request to an AI model. This interface encapsulates the + * necessary information required to interact with an AI model, including instructions or + * inputs (of generic type T) and additional model options. It provides a standardized way + * to send requests to AI models, ensuring that all necessary details are included and can + * be easily managed. + * + * @param the type of instructions or input required by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelRequest { + + /** + * Retrieves the instructions or input required by the AI model. + * @return the instructions or input required by the AI model + */ + T getInstructions(); // required input + + /** + * Retrieves the customizable options for AI model interactions. + * @return the customizable options for AI model interactions + */ + ModelOptions getOptions(); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java new file mode 100644 index 00000000000..5c8a17b5827 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +import java.util.List; + +/** + * Interface representing the response received from an AI model. This interface provides + * methods to access the main result or a list of results generated by the AI model, along + * with the response metadata. It serves as a standardized way to encapsulate and manage + * the output from AI models, ensuring easy retrieval and processing of the generated + * information. + * + * @param the type of the result(s) provided by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelResponse> { + + /** + * Retrieves the result of the AI model. + * @return the result generated by the AI model + */ + T getResult(); + + /** + * Retrieves the list of generated outputs by the AI model. + * @return the list of generated outputs + */ + List getResults(); + + /** + * Retrieves the response metadata associated with the AI model's response. + * @return the response metadata + */ + ResponseMetadata getMetadata(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java new file mode 100644 index 00000000000..5a5613a7280 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * This interface provides methods to access the main output of the AI model and the + * metadata associated with this result. It is designed to offer a standardized and + * comprehensive way to handle and interpret the outputs generated by AI models, catering + * to diverse AI applications and use cases. + * + * @param the type of the output generated by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelResult { + + /** + * Retrieves the output generated by the AI model. + * @return the output generated by the AI model + */ + T getOutput(); + + /** + * Retrieves the metadata associated with the result of an AI model. + * @return the metadata associated with the result + */ + ResultMetadata getMetadata(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java new file mode 100644 index 00000000000..14af864bb49 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * Interface representing metadata associated with an AI model's response. This interface + * is designed to provide additional information about the generative response from an AI + * model, including processing details and model-specific data. It serves as a value + * object within the core domain, enhancing the understanding and management of AI model + * responses in various applications. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ResponseMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java new file mode 100644 index 00000000000..78d5f7f6a91 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.model; + +/** + * Interface representing metadata associated with the results of an AI model. This + * interface focuses on providing additional context and insights into the results + * generated by AI models. It could include information like computation time, model + * version, or other relevant details that enhance understanding and management of AI + * model outputs in various applications. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ResultMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java new file mode 100644 index 00000000000..12eaa53b400 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides a set of interfaces and classes for a generic API designed to interact with + * various AI models. This package includes interfaces for handling AI model calls, + * requests, responses, results, and associated metadata. It is designed to offer a + * flexible and adaptable framework for interacting with different types of AI models, + * abstracting the complexities involved in model invocation and result processing. The + * use of generics enhances the API's capability to work with a wide range of models, + * ensuring a broad applicability across diverse AI scenarios. 
+ * + */ +package org.springframework.ai.model; \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java index 2770811506a..8d0cb76b099 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java @@ -18,7 +18,7 @@ /** * Implementations of this interface provides instructions for how the output of a - * language model should be formatted. + * language model should be formatted. * * @author Mark Pollack */ @@ -26,7 +26,7 @@ public interface FormatProvider { /** * @return Returns a string containing instructions for how the output of a language - * model should be formatted. + * model should be formatted. */ String getFormat(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java index 9387be7444b..40280f19269 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java @@ -22,12 +22,12 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; /** - * Keyword extractor that uses model to extract 'excerpt_keywords' metadata field. + * Keyword extractor that uses model to extract 'excerpt_keywords' metadata field. 
* * @author Christian Tzolov */ @@ -65,7 +65,7 @@ public List apply(List documents) { var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent())); - String keywords = this.chatClient.generate(prompt).getGeneration().getContent(); + String keywords = this.chatClient.call(prompt).getResult().getOutput().getContent(); document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); } return documents; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index e5162e1b776..f0a607dbda0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -25,14 +25,14 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** - * Title extractor with adjacent sharing that uses model to extract 'section_summary', - * 'prev_section_summary', 'next_section_summary' metadata fields. + * Title extractor with adjacent sharing that uses model to extract + * 'section_summary', 'prev_section_summary', 'next_section_summary' metadata fields. 
* * @author Christian Tzolov */ @@ -102,7 +102,7 @@ public List apply(List documents) { Prompt prompt = new PromptTemplate(this.summaryTemplate) .create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext)); - documentSummaries.add(this.chatClient.generate(prompt).getGeneration().getContent()); + documentSummaries.add(this.chatClient.call(prompt).getResult().getOutput().getContent()); } for (int i = 0; i < documentSummaries.size(); i++) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index 02c18f6b26a..04f6adaecf2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -17,12 +17,13 @@ package org.springframework.ai.vectorstore.filter; /** - * Portable runtime model for metadata filter expressions. This generic model is used to - * define store agnostic filter expressions than later can be converted into vector-store - * specific, native, expressions. + * Portable runtime model for metadata filter expressions. This generic model is + * used to define store agnostic filter expressions that later can be converted into + * vector-store specific, native, expressions. * - * The expression model supports constant comparison {@code (e.g. ==, !=, <, <=, >, >=) }, - * IN/NON-IN checks and AND and OR to compose multiple expressions. + * The expression model supports constant comparison + * {@code (e.g. ==, !=, <, <=, >, >=) }, IN/NON-IN checks and AND and OR to compose + * multiple expressions. 
* * For example: * diff --git a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties index 89f3b63cd3a..849a8ceacf4 100644 --- a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties +++ b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties @@ -1,4 +1,4 @@ -# Map of embedding model names and their dimensions +# Map of embedding model names and their dimensions # OpenAI text-embedding-ada-002=1536 text-similarity-ada-001=1024 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java index 8a796a06b93..1a85939955c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java @@ -21,20 +21,13 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doCallRealMethod; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -import java.util.Collections; +import static org.mockito.Mockito.*; import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.prompt.Prompt; /** * Unit Tests for {@link ChatClient}. 
@@ -51,10 +44,23 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { String responseMessage = "All your bases are belong to us"; ChatClient mockClient = Mockito.mock(ChatClient.class); - Generation generation = spy(new Generation(responseMessage)); - ChatResponse response = spy(new ChatResponse(Collections.singletonList(generation))); - doCallRealMethod().when(mockClient).generate(anyString()); + AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); + when(mockAssistantMessage.getContent()).thenReturn(responseMessage); + + // Create a mock Generation + Generation generation = Mockito.mock(Generation.class); + when(generation.getOutput()).thenReturn(mockAssistantMessage); + + // Create a mock ChatResponse with the mock Generation + ChatResponse response = Mockito.mock(ChatResponse.class); + when(response.getResult()).thenReturn(generation); + + // Generation generation = spy(new Generation(responseMessage)); + // ChatResponse response = spy(new + // ChatResponse(Collections.singletonList(generation))); + + doCallRealMethod().when(mockClient).call(anyString()); doAnswer(invocationOnMock -> { @@ -65,14 +71,15 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { return response; - }).when(mockClient).generate(any(Prompt.class)); + }).when(mockClient).call(any(Prompt.class)); - assertThat(mockClient.generate(userMessage)).isEqualTo(responseMessage); + assertThat(mockClient.call(userMessage)).isEqualTo(responseMessage); - verify(mockClient, times(1)).generate(eq(userMessage)); - verify(mockClient, times(1)).generate(isA(Prompt.class)); - verify(response, times(1)).getGeneration(); - verify(generation, times(1)).getContent(); + verify(mockClient, times(1)).call(eq(userMessage)); + verify(mockClient, times(1)).call(isA(Prompt.class)); + verify(response, times(1)).getResult(); + verify(generation, times(1)).getOutput(); + verify(mockAssistantMessage, times(1)).getContent(); 
verifyNoMoreInteractions(mockClient, generation, response); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java index 45d1e68c446..c8f88d0a1e6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java @@ -18,6 +18,7 @@ import java.util.List; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java index 792b093842a..aef94a6b030 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java @@ -22,7 +22,8 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; /** * Unit Tests for {@link PromptMetadata}. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java index 0507d4c7f16..0e44b63d64f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.Usage; /** * Unit Tests for {@link Usage}. 
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index 74abea55eb2..52dcd7712a6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -2,6 +2,7 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; @@ -18,17 +19,18 @@ public class PromptTemplateTest { @Test public void testRender() { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a model for + // testing Map model = new HashMap<>(); model.put("key1", "value1"); model.put("key2", true); model.put("key3", 100); - // Create a simple template with placeholders for keys in the model + // Create a simple template with placeholders for keys in the model String template = "This is a {key1}, it is {key2}, and it costs {key3}"; PromptTemplate promptTemplate = new PromptTemplate(template, model); - // The expected result after rendering the template with the model + // The expected result after rendering the template with the model String expected = "This is a value1, it is true, and it costs 100"; String result = promptTemplate.render(); @@ -44,7 +46,8 @@ public void testRender() { @Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate") @Test public void testRenderResource() throws Exception { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a model for + // testing Map model = 
new HashMap<>(); model.put("key1", "value1"); model.put("key2", true); @@ -55,11 +58,11 @@ public void testRenderResource() throws Exception { model.put("key3", resource); - // Create a simple template with placeholders for keys in the model + // Create a simple template with placeholders for keys in the model String template = "{key1}, {key2}, {key3}"; PromptTemplate promptTemplate = new PromptTemplate(template, model); - // The expected result after rendering the template with the model + // The expected result after rendering the template with the model String expected = "value1, true, it costs 100"; String result = promptTemplate.render(); @@ -69,11 +72,12 @@ public void testRenderResource() throws Exception { @Test public void testRenderFailure() { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a model for + // testing Map model = new HashMap<>(); model.put("key1", "value1"); - // Create a simple template that includes a key not present in the model + // Create a simple template that includes a key not present in the model String template = "This is a {key2}!"; PromptTemplate promptTemplate = new PromptTemplate(template, model); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java index 22d46555f1a..b08ed9aec70 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -18,6 +18,9 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import java.util.HashMap; import java.util.Map; @@ -54,7 +57,7 @@ void 
newApiPlaygroundTests() { // to have access to Messages Prompt prompt = pt.create(model); assertThat(prompt.getContents()).isNotNull(); - assertThat(prompt.getMessages()).isNotEmpty().hasSize(1); + assertThat(prompt.getInstructions()).isNotEmpty().hasSize(1); System.out.println(prompt.getContents()); String systemTemplate = "You are a helpful assistant that translates {input_language} to {output_language}."; @@ -86,7 +89,7 @@ void newApiPlaygroundTests() { // ChatPromptTemplate chatPromptTemplate = new ChatPromptTemplate(systemPrompt, // humanPrompt); - // Prompt chatPrompt chatPromptTemplate.create(model); + // Prompt chatPrompt chatPromptTemplate.create(model); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc index 722fb3b7f7d..d8f87898c67 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc @@ -22,7 +22,7 @@ python3 -m venv venv source ./venv/bin/activate (venv) pip install --upgrade pip (venv) pip install optimum onnx onnxruntime -(venv) optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder +(venv) optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder ---- The snippet exports the https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2[sentence-transformers/all-MiniLM-L6-v2] transformer into the `onnx-output-folder` folder. Later includes the `tokenizer.json` and `model.onnx` files used by the embedding client. 
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 97d136c5d57..c188fa663c5 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -156,6 +156,14 @@ true + + + org.springframework.ai + spring-ai-stability-ai + ${project.parent.version} + true + + org.springframework.ai diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java index 57d3dac2802..fadcafe1b79 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java @@ -34,11 +34,11 @@ public class AzureOpenAiChatProperties { /** * An alternative to sampling with temperature called nucleus sampling. This value - * causes the model to consider the results of tokens with the provided probability - * mass. As an example, a value of 0.15 will cause only the tokens comprising the top - * 15% of probability mass to be considered. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. + * causes the model to consider the results of tokens with the provided + * probability mass. As an example, a value of 0.15 will cause only the tokens + * comprising the top 15% of probability mass to be considered. It is not recommended + * to modify temperature and top_p for the same completions request as the interaction + * of these two settings is difficult to predict. 
*/ private Double topP; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java index 7ba71f1a96c..c7ddee8537c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java @@ -24,7 +24,7 @@ public class AzureOpenAiEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.azure.openai.embedding"; /** - * The text embedding model to use for the embedding client. + * The text embedding model to use for the embedding client. */ private String model = "text-embedding-ada-002"; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java index 5b22fff7b60..0b3d36f9857 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java @@ -39,23 +39,24 @@ public class BedrockAnthropicChatProperties { private boolean enabled = false; /** - * The model id to use. See the {@link AnthropicChatModel} for the supported models. + * The model id to use. See the {@link AnthropicChatModel} for the supported + * models. */ private String model = AnthropicChatModel.CLAUDE_V2.id(); /** * Controls the randomness of the output. 
Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. - * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * model. This value specifies default to be used by the backend while making the + * call to the model. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * model uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; @@ -68,19 +69,19 @@ public class BedrockAnthropicChatProperties { private Integer maxTokensToSample = 300; /** - * Specify the number of token choices the model uses to generate the next token. + * Specify the number of token choices the model uses to generate the next token. */ private Integer topK = 10; /** - * Configure up to four sequences that the model recognizes. After a stop sequence, - * the model stops generating further tokens. The returned text doesn't contain the - * stop sequence. + * Configure up to four sequences that the model recognizes. After a stop + * sequence, the model stops generating further tokens. The returned text doesn't + * contain the stop sequence. */ private List stopSequences = List.of("\n\nHuman:"); /** - * The version of the model to use. The default value is bedrock-2023-05-31. + * The version of the model to use. The default value is bedrock-2023-05-31. 
*/ private String anthropicVersion = AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java index 97d58bbf3de..0985d55b109 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java @@ -40,7 +40,7 @@ public class BedrockCohereChatProperties { private boolean enabled = false; /** - * Bedrock Cohere Chat model name. Defaults to 'cohere-command-v14'. + * Bedrock Cohere Chat generative name. Defaults to 'cohere-command-v14'. */ private String model = CohereChatBedrockApi.CohereChatModel.COHERE_COMMAND_V14.id(); @@ -52,14 +52,14 @@ public class BedrockCohereChatProperties { /** * (optional) The maximum cumulative probability of tokens to consider when sampling. - * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the - * smallest set of tokens whose probability sum is at least topP. + * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers + * the smallest set of tokens whose probability sum is at least topP. */ private Float topP; /** - * (optional) Specify the number of token choices the model uses to generate the next - * token. + * (optional) Specify the number of token choices the generative uses to generate the + * next token. */ private Integer topK; @@ -69,9 +69,9 @@ public class BedrockCohereChatProperties { private Integer maxTokens; /** - * (optional) Configure up to four sequences that the model recognizes. After a stop - * sequence, the model stops generating further tokens. 
The returned text doesn't - * contain the stop sequence. + * (optional) Configure up to four sequences that the generative recognizes. After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. */ private List stopSequences; @@ -81,19 +81,19 @@ public class BedrockCohereChatProperties { private ReturnLikelihoods returnLikelihoods; /** - * (optional) The maximum number of generations that the model should return. + * (optional) The maximum number of generations that the generative should return. */ private Integer numGenerations; /** - * LogitBias prevents the model from generating unwanted tokens or incentivize the - * model to include desired tokens. The token likelihoods. + * LogitBias prevents the generative from generating unwanted tokens or incentivize + * the generative to include desired tokens. The token likelihoods. */ private String logitBiasToken; /** - * LogitBias prevents the model from generating unwanted tokens or incentivize the - * model to include desired tokens. A float between -10 and 10. + * LogitBias prevents the generative from generating unwanted tokens or incentivize + * the generative to include desired tokens. A float between -10 and 10. 
*/ private Float logitBiasBias; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java index 7197d99ed8b..fc1c56996a3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java @@ -38,7 +38,8 @@ public class BedrockCohereEmbeddingProperties { private boolean enabled = false; /** - * Bedrock Cohere Embedding model name. Defaults to 'cohere.embed-multilingual-v3'. + * Bedrock Cohere Embedding generative name. Defaults to + * 'cohere.embed-multilingual-v3'. */ private String model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java index 4c7b05ebef6..9d2c3d7e28f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java @@ -38,27 +38,27 @@ public class BedrockLlama2ChatProperties { /** * Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. 
- * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * generative. This value specifies default to be used by the backend while making the + * call to the generative. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; /** - * Specify the maximum number of tokens to use in the generated response. The model - * truncates the response once the generated text exceeds maxGenLen. + * Specify the maximum number of tokens to use in the generated response. The + * generative truncates the response once the generated text exceeds maxGenLen. */ private Integer maxGenLen = 300; /** - * The model id to use. See the {@link Llama2ChatModel} for the supported models. + * The generative id to use. See the {@link Llama2ChatModel} for the supported models. 
*/ private String model = Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java index f72adf15d47..60d2ce71291 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java @@ -38,7 +38,7 @@ public class BedrockTitanChatProperties { private boolean enabled = false; /** - * Bedrock Titan Chat model name. Defaults to 'amazon.titan-text-express-v1'. + * Bedrock Titan Chat generative name. Defaults to 'amazon.titan-text-express-v1'. */ private String model = TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(); @@ -50,8 +50,8 @@ public class BedrockTitanChatProperties { /** * (optional) The maximum cumulative probability of tokens to consider when sampling. - * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the - * smallest set of tokens whose probability sum is at least topP. + * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers + * the smallest set of tokens whose probability sum is at least topP. */ private Float topP; @@ -61,9 +61,9 @@ public class BedrockTitanChatProperties { private Integer maxTokenCount; /** - * (optional) Configure up to four sequences that the model recognizes. After a stop - * sequence, the model stops generating further tokens. The returned text doesn't - * contain the stop sequence. + * (optional) Configure up to four sequences that the generative recognizes. After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. 
*/ private List stopSequences; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java index f74026c813c..1a6b1e5c141 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java @@ -37,7 +37,7 @@ public class BedrockTitanEmbeddingProperties { private boolean enabled = false; /** - * Bedrock Titan Embedding model name. Defaults to 'amazon.titan-embed-image-v1'. + * Bedrock Titan Embedding generative name. Defaults to 'amazon.titan-embed-image-v1'. */ private String model = TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java index ed8762fc3de..1b07aca6f7b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java @@ -31,14 +31,14 @@ public class OllamaChatProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.chat"; /** - * Ollama Chat model name. Defaults to 'llama2'. + * Ollama Chat generative name. Defaults to 'llama2'. */ private String model = "llama2"; /** - * Client lever Ollama options. Use this property to configure model temperature, topK - * and topP and alike parameters. 
The null values are ignored defaulting to the - * model's defaults. + * Client lever Ollama options. Use this property to configure generative temperature, + * topK and topP and alike parameters. The null values are ignored defaulting to the + * generative's defaults. */ private OllamaOptions options = new OllamaOptions(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java index a4dd65a7ac1..f4039bf57f5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java @@ -31,14 +31,14 @@ public class OllamaEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding"; /** - * Ollama Embedding model name. Defaults to 'llama2'. + * Ollama Embedding generative name. Defaults to 'llama2'. */ private String model = "llama2"; /** - * Client lever Ollama options. Use this property to configure model temperature, topK - * and topP and alike parameters. The null values are ignored defaulting to the - * model's defaults. + * Client lever Ollama options. Use this property to configure generative temperature, + * topK and topP and alike parameters. The null values are ignored defaulting to the + * generative's defaults. 
*/ private OllamaOptions options = new OllamaOptions(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java new file mode 100644 index 00000000000..485c2eb841f --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java @@ -0,0 +1,33 @@ +package org.springframework.ai.autoconfigure.stabilityai; + +import org.springframework.ai.autoconfigure.NativeHints; +import org.springframework.ai.stabilityai.StabilityAiImageClient; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ImportRuntimeHints; + +@AutoConfiguration +@ConditionalOnClass(StabilityAiApi.class) +@EnableConfigurationProperties({ StabilityAiProperties.class }) +@ImportRuntimeHints(NativeHints.class) +public class StabilityAiImageAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public StabilityAiApi stabilityAiApi(StabilityAiProperties stabilityAiProperties) { + return new StabilityAiApi(stabilityAiProperties.getApiKey(), stabilityAiProperties.getBaseUrl(), + stabilityAiProperties.getOptions().getModel()); + } + + @Bean + @ConditionalOnMissingBean + public StabilityAiImageClient stabilityAiImageClient(StabilityAiApi stabilityAiApi, + StabilityAiProperties stabilityAiProperties) { + return new StabilityAiImageClient(stabilityAiApi, 
stabilityAiProperties.getOptions()); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java new file mode 100644 index 00000000000..5955999dbf2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java @@ -0,0 +1,44 @@ +package org.springframework.ai.autoconfigure.stabilityai; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(StabilityAiProperties.CONFIG_PREFIX) +public class StabilityAiProperties { + + public static final String CONFIG_PREFIX = "spring.ai.stabilityai"; + + private String apiKey; + + private String baseUrl = StabilityAiApi.DEFAULT_BASE_URL; + + @NestedConfigurationProperty + private StabilityAiImageOptions options; + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + + public StabilityAiImageOptions getOptions() { + return options; + } + + public void setOptions(StabilityAiImageOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java index 8bc236270e7..6b1fa816df1 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java @@ -39,7 +39,7 @@ public class TransformersEmbeddingClientProperties { public static final String CONFIG_PREFIX = "spring.ai.embedding.transformer"; public static final String DEFAULT_CACHE_DIRECTORY = new File(System.getProperty("java.io.tmpdir"), - "spring-ai-onnx-model") + "spring-ai-onnx-generative") .getAbsolutePath(); /** @@ -91,7 +91,7 @@ public static class Cache { /** * Resource cache directory. Used to cache remote resources, such as the ONNX * models, to the local file system. Applicable only for cache.enabled == true. - * Defaults to {java.io.tmpdir}/spring-ai-onnx-model. + * Defaults to {java.io.tmpdir}/spring-ai-onnx-generative. */ private String directory = DEFAULT_CACHE_DIRECTORY; @@ -125,7 +125,7 @@ public Cache getCache() { public static class Onnx { /** - * Existing, pre-trained ONNX model. Commonly exported from + * Existing, pre-trained ONNX generative. Commonly exported from * https://sbert.net/docs/pretrained_models.html. Defaults to * sentence-transformers/all-MiniLM-L6-v2. */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java index a53f6567098..67b7ba714c3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java @@ -27,16 +27,16 @@ public class VertexAiChatProperties { /** * Controls the randomness of the output. 
Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. - * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * generative. This value specifies default to be used by the backend while making the + * call to the generative. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; @@ -47,14 +47,14 @@ public class VertexAiChatProperties { private Integer candidateCount = 1; /** - * The maximum number of tokens to consider when sampling. The model uses combined - * Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable - * tokens. + * The maximum number of tokens to consider when sampling. The generative uses + * combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most + * probable tokens. */ private Integer topK = 20; /** - * Vertex AI PaLM API model name. Defaults to chat-bison-001 + * Vertex AI PaLM API generative name. 
Defaults to chat-bison-001 */ private String model = VertexAiApi.DEFAULT_GENERATE_MODEL; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java index dcdb8d140f4..433e9039248 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java @@ -25,7 +25,7 @@ public class VertexAiEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.embedding"; /** - * Vertex AI PaLM API embedding model name. Defaults to embedding-gecko-001. + * Vertex AI PaLM API embedding generative name. Defaults to embedding-gecko-001. */ private String model = VertexAiApi.DEFAULT_EMBEDDING_MODEL; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 3786b5377fc..240fdc4db64 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,5 +1,6 @@ org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration +org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingClientAutoConfiguration 
org.springframework.ai.autoconfigure.huggingface.HuggingfaceChatAutoConfiguration org.springframework.ai.autoconfigure.vertexai.VertexAiAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 9caf7368bba..4184358ab7d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; @@ -30,10 +31,10 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -78,8 +79,8 @@ public class AzureOpenAiAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { AzureOpenAiChatClient chatClient = 
context.getBean(AzureOpenAiChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -95,9 +96,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 416bb867695..8f2a2eaff84 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import 
org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -70,8 +71,8 @@ public class BedrockAnthropicChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockAnthropicChatClient anthropicChatClient = context.getBean(BedrockAnthropicChatClient.class); - ChatResponse response = anthropicChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = anthropicChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -88,9 +89,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index f80a9b78dd1..b40d337afba 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -32,10 +33,10 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -72,8 +73,8 @@ public class BedrockCohereChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockCohereChatClient cohereChatClient = context.getBean(BedrockCohereChatClient.class); - ChatResponse response = cohereChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = cohereChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + 
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -90,9 +91,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java index 7ecefc6ba74..97435c61d50 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import 
org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -70,8 +71,8 @@ public class BedrockLlama2ChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class); - ChatResponse response = llama2ChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = llama2ChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -88,9 +89,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 835217e60be..44e7ef97503 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ 
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.titan.BedrockTitanChatClient; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -70,8 +71,8 @@ public class BedrockTitanChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockTitanChatClient chatClient = context.getBean(BedrockTitanChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -87,9 +88,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + 
.map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java index 2c0a7f1a1ee..0679da52dca 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java @@ -23,11 +23,12 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.testcontainers.containers.GenericContainer; @@ -61,7 +62,7 @@ public class OllamaAutoConfigurationIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL_NAME + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... 
would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); logger.info(MODEL_NAME + " pulling competed!"); @@ -88,8 +89,8 @@ public static void beforeAll() throws IOException, InterruptedException { public void chatCompletion() { contextRunner.run(context -> { OllamaChatClient chatClient = context.getBean(OllamaChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -105,9 +106,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java index 9ad9be0096e..d7465fc9828 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java @@ -17,7 +17,7 @@ package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; -import org.springframework.ai.ollama.OllamaChatClient; + import org.springframework.boot.autoconfigure.AutoConfigurations; import 
org.springframework.boot.test.context.runner.ApplicationContextRunner; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 51a4a60d9dc..6cae39e9700 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -53,7 +53,7 @@ public class OllamaEmbeddingAutoConfigurationIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL_NAME + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); logger.info(MODEL_NAME + " pulling competed!"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index e9976f11cf5..63d8287dc4e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -17,7 +17,7 @@ package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; -import org.springframework.ai.ollama.OllamaEmbeddingClient; + import org.springframework.boot.autoconfigure.AutoConfigurations; import 
org.springframework.boot.test.context.runner.ApplicationContextRunner; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 17e8a0c5c71..748bc822a5d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -23,14 +23,14 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; import reactor.core.publisher.Flux; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -49,7 +49,7 @@ public class OpenAiAutoConfigurationIT { void generate() { contextRunner.run(context -> { OpenAiChatClient client = context.getBean(OpenAiChatClient.class); - String response = client.generate("Hello"); + String response = client.call("Hello"); assertThat(response).isNotEmpty(); logger.info("Response: " + response); }); @@ -60,9 +60,8 @@ void generateStreaming() { contextRunner.run(context -> { OpenAiChatClient client = context.getBean(OpenAiChatClient.class); Flux responseFlux = client.generateStream(new Prompt(new UserMessage("Hello"))); - String 
response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getGenerations().get(0).getContent(); + return chatResponse.getResults().get(0).getOutput().getContent(); }).collect(Collectors.joining()); assertThat(response).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java index 67121339cf4..8b497830a40 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java @@ -45,8 +45,8 @@ public void embedding() { contextRunner.run(context -> { var properties = context.getBean(TransformersEmbeddingClientProperties.class); assertThat(properties.getCache().isEnabled()).isTrue(); - assertThat(properties.getCache().getDirectory()) - .isEqualTo(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath()); + assertThat(properties.getCache().getDirectory()).isEqualTo( + new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-generative").getAbsolutePath()); EmbeddingClient embeddingClient = context.getBean(EmbeddingClient.class); assertThat(embeddingClient).isInstanceOf(TransformersEmbeddingClient.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java index e650a68672b..429e4d385fc 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java @@ -48,7 +48,7 @@ void generate() { contextRunner.run(context -> { VertexAiChatClient client = context.getBean(VertexAiChatClient.class); - String response = client.generate("Hello"); + String response = client.call("Hello"); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java index b8b8ca23c6b..f7bf74db4b0 100644 --- a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java +++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java @@ -20,10 +20,10 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -55,7 +55,7 @@ public class BasicEvaluationTest { protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getContent(); + String 
answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -69,12 +69,12 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response, } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatClient.generate(prompt).getGeneration().getContent(); + String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatClient.generate(prompt).getGeneration().getContent(); + String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else {