From a0f71f7c9bc87dba73234cac41c1cc040946af66 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Mon, 15 Jan 2024 00:06:24 -0500 Subject: [PATCH 01/11] This is a squahed commit of the 'options-ftw' comment. Short story is this PR provides * An abstract API for AI model clients * Providing portable client request options while still allowing vendor specific options when required. Implemented only for StabilityAI/OpenAI ImageClient Open TODOs are to * implement the options design pattern across the code base The following is the commit history of squashed commits as is: Step 1 - Implementing abstract generative API === Refactoring === * Create generative package to place all abstract API components of a Generative API. This is the "core domain" in DDD parlance * Temp Renamed ChoiceMetadata to GenerationChoiceMetadata to better identify that it is part of the Respoinse. Note, there is to be a list of Generation instances in the response. ==== Analysis ==== The currently named GeneratioMetadata contains RateLimit, Usage. What metadata applies to the "top" level response (e.g ChatResponse) * Usage * RateLimit (derived from response headers) * PromptMetadata What metadata applies to eachh individual generation * GenerationChoiceMetadata (now GenerationMetadata) === Refactoring === * ChatResponse should have one single class "ResponseMetadata" Now it has two 'GenerationMetadata' and 'PromptMetadata'. ** Rename GenerationMetadata to ChatResponseMetadata. ** Move ChatRespone field PromptMetadata into ChatRespoinseMetadata. ** Remove getPromptMetadata() and withPrompotMetadata into ChatResponseMetadata * Add PromptMetadata to AzureOpenAiGenerationMetadata.from() method * Rename AzureOpenAiGenerationMetadata to AzureOpenChatResponseMetadata * Rename OpenAiGenerationMetadata to OpenAiChatResponseMetadata * Add GenerativePrompt public interface GenerativePrompt { T getInstructions(); // required input Options getOptions(); } * Add GenerativeResponse === TODOs === * The RateLimit is only implemented for OpenAI, not in Azure OpenAI, is it available? Needs research, ask John * The PromptMetdata is only implemented for Azure OpenAI, not OpenAI, is it available? NO Step 2 - Implementing abstract generative api === Refactoring === * Refactor ChatResponseMetadata extends ResponseMetadata. ResponseMetadata is just a marker interface. * Refactor Prompt class. ow is defined as Prompt implements GenerativePrompt> { } NOTE: This should change in future to either ChatPrompt in the prompt package or better, in the 'chat' package. May consider putting in 'chat' package and keeping generic Prompt in the name for stylistic purposes. Perhaps just takes getting used to. * Add public interface GenerativeClient, TRes extends GenerativeResponse> { TReq generate(TReq prompt); } * ChatClient refactor, now defined as public interface ChatClient extends GenerativeClient { // omitted. } * Rename GenerationChoiceMetadata to GenerationMetadata Step 3 - Implementing abstract generative api * Add GenerativeGeneration interface public interface GenerativeGeneration { T getOutput(); GenerationMetadata getGenerationMetadata(); } * Refactor class generation to be public class Generation implements GenerativeGeneration { private AssistantMessage assistantMessage; private GenerationMetadata generationMetadata; } * Rename GenerationMeadata to ChatGenerationMetadata * Add marker interface public interface GenerationMetadata {} * Move metadata package under chat package for now as it is very specific to chat use cases. * Add JAvadocs to generative and chat package. == TODO == Need more work on 'portable' ResponseMetadata and GenerationMetadata Maybe in AbstactMessage protected String content; public String getContent() { return this.content; } At the moment changing getContent to getText() would be more legible since it is returning 'text' as a datatype as compared to 'content' which sounds generic and 'output' is already generic. Maybe in the future we will introduce generics into the Message class, in that case getContent is better, the content returned could be an image insteady of text. NEed to see what multi-model responses would be like. Keep getContent for now. Getting Build error on tests java.io.FileNotFoundException: class path resource [embedding/embedding-generative-dimensions.properties] cannot be opened because it does not exist [in thread "main"] Fix tests that broke due to bad refactorings Step 4 - Implementing abstract generative API Rename generative methods and classes to more generic model ones * Change package name from 'generative' to 'model' * Change class names that used prefix 'Generative' to 'Model' Multiple generative-related methods and classes have been renamed to more generic 'model' equivalents. This includes changing 'GenerativeClient' to 'ModelCall', and generative methods like 'generate' were renamed to 'call'. This was done in alignment to a broader model function definition - providing flexibility for non-generative model implementations. add text->image from openai Rename ModelCall to ModelClient and getResultMetadata() to getMetadata() Add Text to Image generation for OpenAI and Stability AI * ImagePrompt is the portable prompt that takes a string as input * ImageOptions contains the portable options between the two providers * ImageGenerationMetadata is a marker interface, as no metadata is shared between the two providers * OpenAiImageOptions and StabilityAiImageOptions are provider specific options * Implement builders for the two provider ImageOptions implementations * OpenAiImageGenerationMetadata has revisedPrompt field * StabilityAiImageGenerationMetadata has finishReason and seed fields * OpenAiImageClient merges object creating time options with runtime options * Add provider specific StabilityAiImagePrompt implements ModelRequest> * StabilityAiImageMessage has text and weight fields * StabilityAiImageClient implements ImageResponse call(ImagePrompt prompt) and add providers specific ImageResponse call(StabilityAiImagePrompt stabilityAiImagePrompt) * StabilityAiImageClient has field `StabilityAiImageOptions options;` * Add basic tests for StabilityAiApi and OpenAiImageApi TODOs * Investigate using json to copy options vs. hand coded * Use OpenAiImageOptions in OpenAiImageClient * Use of OpenAiImageOptions and StabilityAiImageOptions in @ConfigurationProperties Implement StabilityAiImageClient that supports portable and StabilityAI specifc prompt * Add ModelOptionsUtils * Add Tests. Comment in writeToFile to store image as png file * StabilityAiProperties contains field `StabilityAiImageOptions options` * Add StabilityAiImageAutoConfiguration * StabilityAiImageClient applies runtime prompt options on top of options created at instantiation time. update autoconfig resource file to include StabilityAI autoconf Refactor StabilityAI and OpenAI code to use same ImagePrompt * Simplify StabilityAI and OpenAI implementations to use a common shared ImagePrompt based on a similar 'Message' data type style used in ChatClient implementations * Change StabilityAiImageOptions `seed` property type from `Integer` to `Long` and update seed type to Long TODOs ===== Image.java should have * getBytes() * An enum that indicates the type of payload received, a url or a base64 or byte array. Convert from url/base64 to byte array. ** No support yet in stability AI for returing directly a png. AcceptHeader media/png ** need also an enum to indiate if the byte array or base64 is png,gif,jpg, etc. add nested config property annotation --- .../azure/openai/AzureOpenAiChatClient.java | 34 +-- ...a => AzureOpenAiChatResponseMetadata.java} | 29 +- .../openai/metadata/AzureOpenAiUsage.java | 2 +- .../azure/openai/AzureOpenAiChatClientIT.java | 36 +-- .../AzureOpenAiChatClientMetadataTests.java | 38 +-- .../ai/bedrock/BedrockUsage.java | 2 +- .../ai/bedrock/MessageToPromptConverter.java | 4 +- .../anthropic/BedrockAnthropicChatClient.java | 16 +- .../cohere/BedrockCohereChatClient.java | 14 +- .../llama2/BedrockLlama2ChatClient.java | 20 +- .../bedrock/titan/BedrockTitanChatClient.java | 17 +- .../BedrockAnthropicChatClientIT.java | 32 +-- .../cohere/BedrockCohereChatClientIT.java | 32 +-- .../llama2/BedrockLlama2ChatClientIT.java | 32 +-- .../titan/BedrockTitanChatClientIT.java | 32 +-- .../ai/huggingface/HuggingfaceChatClient.java | 4 +- .../ai/huggingface/client/ClientIT.java | 12 +- models/spring-ai-ollama/pom.xml | 6 + .../ai/ollama/OllamaChatClient.java | 66 ++++- .../ai/ollama/api/OllamaOptions.java | 3 +- .../ai/ollama/OllamaChatClientIT.java | 44 +-- .../ai/ollama/OllamaEmbeddingClientIT.java | 2 +- .../ai/ollama/api/OllamaApiIT.java | 2 +- ...ests.java => OllamaModelOptionsTests.java} | 2 +- .../ai/openai/OpenAiChatClient.java | 22 +- .../ai/openai/OpenAiImageClient.java | 126 +++++++++ .../ai/openai/api/OpenAiApi.java | 2 +- .../ai/openai/api/OpenAiImageApi.java | 250 ++++++++++++++++++ .../ai/openai/api/OpenAiImageOptions.java | 13 + .../openai/api/OpenAiImageOptionsBuilder.java | 59 +++++ .../ai/openai/api/OpenAiImageOptionsImpl.java | 93 +++++++ ...a.java => OpenAiChatResponseMetadata.java} | 28 +- .../OpenAiImageGenerationMetadata.java | 38 +++ .../metadata/OpenAiImageResponseMetadata.java | 46 ++++ .../ai/openai/metadata/OpenAiRateLimit.java | 2 +- .../ai/openai/metadata/OpenAiUsage.java | 2 +- .../OpenAiResponseHeaderExtractor.java | 2 +- .../ai/openai/OpenAiTestConfiguration.java | 17 +- .../ai/openai/acme/AcmeIT.java | 12 +- .../ai/openai/chat/OpenAiChatClientIT.java | 38 +-- ...tClientWithChatResponseMetadataTests.java} | 34 ++- .../ai/openai/embedding/EmbeddingIT.java | 2 +- .../ai/openai/image/OpenAiImageClientIT.java | 25 ++ ...eClientWithImageResponseMetadataTests.java | 142 ++++++++++ .../ai/openai/testutils/AbstractIT.java | 18 +- .../transformer/MetadataTransformerIT.java | 2 +- models/spring-ai-stabilityai/pom.xml | 60 +++++ .../stabilityai/StabilityAiImageClient.java | 158 +++++++++++ .../StabilityAiImageGenerationMetadata.java | 45 ++++ .../ai/stabilityai/api/StabilityAiApi.java | 212 +++++++++++++++ .../api/StabilityAiImageOptions.java | 23 ++ .../api/StabilityAiImageOptionsBuilder.java | 79 ++++++ .../api/StabilityAiImageOptionsImpl.java | 140 ++++++++++ .../ai/stabilityai/StabilityAiApiIT.java | 64 +++++ .../stabilityai/StabilityAiImageClientIT.java | 50 ++++ .../StabilityAiImageTestConfiguration.java | 30 +++ .../ai/transformers/ResourceCacheService.java | 2 +- .../TransformersEmbeddingClient.java | 8 +- .../ai/transformers/samples/ONNXSample.java | 2 +- .../ai/vertex/VertexAiChatClient.java | 10 +- .../ai/vertex/api/VertexAiApi.java | 4 +- .../ai/vertex/api/VertexAiApiTests.java | 8 +- .../VertexAiChatGenerationClientIT.java | 26 +- pom.xml | 5 +- .../springframework/ai/chat/ChatClient.java | 13 +- .../springframework/ai/chat/ChatOptions.java | 24 ++ .../ai/chat/ChatOptionsBuilder.java | 73 +++++ .../springframework/ai/chat/ChatResponse.java | 57 ++-- .../springframework/ai/chat/Generation.java | 53 ++-- .../ai/chat/StreamingChatClient.java | 2 +- .../messages/AbstractMessage.java | 4 +- .../messages/AssistantMessage.java | 8 +- .../messages/ChatMessage.java | 2 +- .../messages/FunctionMessage.java | 2 +- .../ai/{prompt => chat}/messages/Message.java | 2 +- .../messages/MessageType.java | 2 +- .../messages/SystemMessage.java | 7 +- .../messages/UserMessage.java | 4 +- .../metadata/AbstractRateLimit.java | 2 +- .../ai/{ => chat}/metadata/AbstractUsage.java | 2 +- .../metadata/ChatGenerationMetadata.java} | 19 +- .../metadata/ChatResponseMetadata.java} | 12 +- .../{ => chat}/metadata/PromptMetadata.java | 2 +- .../ai/{ => chat}/metadata/RateLimit.java | 2 +- .../ai/{ => chat}/metadata/Usage.java | 2 +- .../springframework/ai/chat/package-info.java | 14 + .../prompt/AssistantPromptTemplate.java | 4 +- .../{ => chat}/prompt/ChatPromptTemplate.java | 4 +- .../prompt/FunctionPromptTemplate.java | 2 +- .../ai/{ => chat}/prompt/Prompt.java | 43 ++- .../ai/{ => chat}/prompt/PromptTemplate.java | 8 +- .../prompt/PromptTemplateActions.java | 2 +- .../prompt/PromptTemplateChatActions.java | 4 +- .../prompt/PromptTemplateMessageActions.java | 4 +- .../prompt/PromptTemplateStringActions.java | 2 +- .../prompt/SystemPromptTemplate.java | 6 +- .../ai/{ => chat}/prompt/TemplateFormat.java | 2 +- .../ai/document/DefaultContentFormatter.java | 7 +- .../ai/embedding/EmbeddingClient.java | 3 +- .../ai/embedding/EmbeddingUtil.java | 8 +- .../org/springframework/ai/image/Image.java | 57 ++++ .../springframework/ai/image/ImageClient.java | 10 + .../ai/image/ImageGeneration.java | 35 +++ .../ai/image/ImageGenerationMetadata.java | 7 + .../ai/image/ImageMessage.java | 47 ++++ .../ai/image/ImageOptions.java | 17 ++ .../ai/image/ImageOptionsBuilder.java | 103 ++++++++ .../springframework/ai/image/ImagePrompt.java | 65 +++++ .../ai/image/ImageResponse.java | 60 +++++ .../ai/image/ImageResponseMetadata.java | 14 + .../ai/image/NewImageClient.java | 12 + .../springframework/ai/model/ModelClient.java | 26 ++ .../ai/model/ModelOptions.java | 15 ++ .../ai/model/ModelOptionsUtils.java | 72 +++++ .../ai/model/ModelRequest.java | 28 ++ .../ai/model/ModelResponse.java | 38 +++ .../springframework/ai/model/ModelResult.java | 27 ++ .../ai/model/ResponseMetadata.java | 15 ++ .../ai/model/ResultMetadata.java | 15 ++ .../ai/model/package-info.java | 11 + .../ai/parser/FormatProvider.java | 4 +- .../transformer/KeywordMetadataEnricher.java | 8 +- .../transformer/SummaryMetadataEnricher.java | 10 +- .../ai/vectorstore/filter/Filter.java | 11 +- .../embedding-model-dimensions.properties | 2 +- .../ai/chat/ChatClientTests.java | 45 ++-- .../ai/embedding/EmbeddingUtilTests.java | 1 + .../ai/metadata/PromptMetadataTests.java | 3 +- .../ai/metadata/UsageTests.java | 1 + .../ai/prompt/PromptTemplateTest.java | 20 +- .../ai/prompt/PromptTests.java | 7 +- .../ROOT/pages/api/embeddings/onnx.adoc | 2 +- spring-ai-spring-boot-autoconfigure/pom.xml | 8 + .../openai/AzureOpenAiChatProperties.java | 10 +- .../AzureOpenAiEmbeddingProperties.java | 2 +- .../BedrockAnthropicChatProperties.java | 25 +- .../cohere/BedrockCohereChatProperties.java | 26 +- .../BedrockCohereEmbeddingProperties.java | 3 +- .../llama2/BedrockLlama2ChatProperties.java | 18 +- .../titan/BedrockTitanChatProperties.java | 12 +- .../BedrockTitanEmbeddingProperties.java | 2 +- .../ollama/OllamaChatProperties.java | 8 +- .../ollama/OllamaEmbeddingProperties.java | 8 +- .../StabilityAiImageAutoConfiguration.java | 33 +++ .../stabilityai/StabilityAiProperties.java | 44 +++ ...TransformersEmbeddingClientProperties.java | 6 +- .../vertexai/VertexAiChatProperties.java | 20 +- .../vertexai/VertexAiEmbeddingProperties.java | 2 +- ...ot.autoconfigure.AutoConfiguration.imports | 1 + .../azure/AzureOpenAiAutoConfigurationIT.java | 22 +- ...eOpenAiAutoConfigurationPropertyTests.java | 5 +- ...drockAnthropicChatAutoConfigurationIT.java | 22 +- .../BedrockCohereChatAutoConfigurationIT.java | 22 +- .../BedrockLlama2ChatAutoConfigurationIT.java | 22 +- .../BedrockTitanChatAutoConfigurationIT.java | 22 +- ...rockTitanEmbeddingAutoConfigurationIT.java | 12 +- .../ollama/OllamaAutoConfigurationIT.java | 22 +- .../ollama/OllamaAutoConfigurationTests.java | 2 +- .../OllamaEmbeddingAutoConfigurationIT.java | 5 +- ...OllamaEmbeddingAutoConfigurationTests.java | 3 +- .../openai/OpenAiAutoConfigurationIT.java | 2 +- ...ersEmbeddingClientAutoConfigurationIT.java | 4 +- .../vertexai/VertexAiAutoConfigurationIT.java | 6 +- .../ai/evaluation/BasicEvaluationTest.java | 14 +- 164 files changed, 3341 insertions(+), 647 deletions(-) rename models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/{AzureOpenAiGenerationMetadata.java => AzureOpenAiChatResponseMetadata.java} (62%) rename models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/{OllamaOptionsTests.java => OllamaModelOptionsTests.java} (97%) create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java rename models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/{OpenAiGenerationMetadata.java => OpenAiChatResponseMetadata.java} (66%) create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/{OpenAiChatClientWithGenerationMetadataTests.java => OpenAiChatClientWithChatResponseMetadataTests.java} (82%) create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java create mode 100644 models/spring-ai-stabilityai/pom.xml create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java create mode 100644 models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java create mode 100644 models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java create mode 100644 models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java create mode 100644 models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/AbstractMessage.java (97%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/AssistantMessage.java (82%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/ChatMessage.java (96%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/FunctionMessage.java (95%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/Message.java (94%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/MessageType.java (95%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/SystemMessage.java (88%) rename spring-ai-core/src/main/java/org/springframework/ai/{prompt => chat}/messages/UserMessage.java (93%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/metadata/AbstractRateLimit.java (96%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/metadata/AbstractUsage.java (95%) rename spring-ai-core/src/main/java/org/springframework/ai/{metadata/ChoiceMetadata.java => chat/metadata/ChatGenerationMetadata.java} (73%) rename spring-ai-core/src/main/java/org/springframework/ai/{metadata/GenerationMetadata.java => chat/metadata/ChatResponseMetadata.java} (79%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/metadata/PromptMetadata.java (98%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/metadata/RateLimit.java (98%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/metadata/Usage.java (97%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/AssistantPromptTemplate.java (90%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/ChatPromptTemplate.java (95%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/FunctionPromptTemplate.java (94%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/Prompt.java (53%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/PromptTemplate.java (96%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/PromptTemplateActions.java (94%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/PromptTemplateChatActions.java (66%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/PromptTemplateMessageActions.java (61%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/PromptTemplateStringActions.java (75%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/SystemPromptTemplate.java (89%) rename spring-ai-core/src/main/java/org/springframework/ai/{ => chat}/prompt/TemplateFormat.java (96%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/Image.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index 72253696b96..5decb221b0b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -32,18 +32,18 @@ import com.azure.core.util.IterableStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; -import org.springframework.ai.azure.openai.metadata.AzureOpenAiGenerationMetadata; +import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; import org.springframework.util.Assert; /** @@ -134,7 +134,7 @@ public AzureOpenAiChatClient withMaxTokens(Integer maxTokens) { } @Override - public String generate(String text) { + public String call(String text) { ChatRequestMessage azureChatMessage = new ChatRequestUserMessage(text); @@ -160,7 +160,7 @@ public String generate(String text) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(false); @@ -174,11 +174,12 @@ public ChatResponse generate(Prompt prompt) { List generations = chatCompletions.getChoices() .stream() .map(choice -> new Generation(choice.getMessage().getContent()) - .withChoiceMetadata(generateChoiceMetadata(choice))) + .withGenerationMetadata(generateChoiceMetadata(choice))) .toList(); - return new ChatResponse(generations, AzureOpenAiGenerationMetadata.from(chatCompletions)) - .withPromptMetadata(generatePromptMetadata(chatCompletions)); + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + return new ChatResponse(generations, + AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata)); } @Override @@ -199,14 +200,17 @@ public Flux generateStream(Prompt prompt) { .flatMap(List::stream) .map(choice -> { var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null; - var generation = new Generation(content).withChoiceMetadata(generateChoiceMetadata(choice)); + var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); return new ChatResponse(List.of(generation)); })); } private ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { - List azureMessages = prompt.getMessages().stream().map(this::fromSpringAiMessage).toList(); + List azureMessages = prompt.getInstructions() + .stream() + .map(this::fromSpringAiMessage) + .toList(); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); @@ -233,8 +237,8 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { } - private ChoiceMetadata generateChoiceMetadata(ChatChoice choice) { - return ChoiceMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults()); + private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { + return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults()); } private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java similarity index 62% rename from models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java rename to models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java index 51208526da8..4472a6df2c3 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java @@ -18,38 +18,44 @@ import com.azure.ai.openai.models.ChatCompletions; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** - * {@link GenerationMetadata} implementation for + * {@link ChatResponseMetadata} implementation for * {@literal Microsoft Azure OpenAI Service}. * * @author John Blum - * @see org.springframework.ai.metadata.GenerationMetadata + * @see ChatResponseMetadata * @since 0.7.1 */ -public class AzureOpenAiGenerationMetadata implements GenerationMetadata { +public class AzureOpenAiChatResponseMetadata implements ChatResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; @SuppressWarnings("all") - public static AzureOpenAiGenerationMetadata from(ChatCompletions chatCompletions) { + public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions, + PromptMetadata promptFilterMetadata) { Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); String id = chatCompletions.getId(); AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); - AzureOpenAiGenerationMetadata generationMetadata = new AzureOpenAiGenerationMetadata(id, usage); - return generationMetadata; + AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage, + promptFilterMetadata); + return chatResponseMetadata; } private final String id; private final Usage usage; - protected AzureOpenAiGenerationMetadata(String id, AzureOpenAiUsage usage) { + private final PromptMetadata promptMetadata; + + protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) { this.id = id; this.usage = usage; + this.promptMetadata = promptMetadata; } public String getId() { @@ -61,6 +67,11 @@ public Usage getUsage() { return this.usage; } + @Override + public PromptMetadata getPromptMetadata() { + return this.promptMetadata; + } + @Override public String toString() { return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java index 73bdc47722a..b5af6fe603b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java @@ -19,7 +19,7 @@ import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.CompletionsUsage; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 141fc0fd8bc..d82db85c31e 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -14,14 +14,15 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -53,8 +54,8 @@ void roleTest() { UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates."); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatClient.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -89,9 +90,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -108,9 +109,9 @@ void beanOutputParser() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - ActorsFilms actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isNotNull(); } @@ -129,9 +130,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatClient.generate(prompt).getGeneration(); + Generation generation = chatClient.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); System.out.println(actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); @@ -154,9 +155,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .filter(Objects::nonNull) .collect(Collectors.joining()); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java index 1a69a6bed6b..1cf366efa07 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java @@ -28,12 +28,13 @@ import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -75,14 +76,15 @@ void azureOpenAiMetadataCapturedDuringGeneration() { Prompt prompt = new Prompt("Can I fly like a bird?"); - ChatResponse response = this.aiClient.generate(prompt); + ChatResponse response = this.aiClient.call(prompt); assertThat(response).isNotNull(); - Generation generation = response.getGeneration(); + Generation generation = response.getResult(); assertThat(generation).isNotNull() - .extracting(Generation::getContent) + .extracting(Generation::getOutput) + .extracting(AssistantMessage::getContent) .isEqualTo("No! You will actually land with a resounding thud. This is the way!"); assertPromptMetadata(response); @@ -92,7 +94,7 @@ void azureOpenAiMetadataCapturedDuringGeneration() { private void assertPromptMetadata(ChatResponse response) { - PromptMetadata promptMetadata = response.getPromptMetadata(); + PromptMetadata promptMetadata = response.getMetadata().getPromptMetadata(); assertThat(promptMetadata).isNotNull(); @@ -106,12 +108,12 @@ private void assertPromptMetadata(ChatResponse response) { private void assertGenerationMetadata(ChatResponse response) { - GenerationMetadata generationMetadata = response.getGenerationMetadata(); + ChatResponseMetadata chatResponseMetadata = response.getMetadata(); - assertThat(generationMetadata).isNotNull(); - assertThat(generationMetadata.getRateLimit()).isEqualTo(RateLimit.NULL); + assertThat(chatResponseMetadata).isNotNull(); + assertThat(chatResponseMetadata.getRateLimit()).isEqualTo(RateLimit.NULL); - Usage usage = generationMetadata.getUsage(); + Usage usage = chatResponseMetadata.getUsage(); assertThat(usage).isNotNull(); assertThat(usage).isNotEqualTo(Usage.NULL); @@ -122,11 +124,11 @@ private void assertGenerationMetadata(ChatResponse response) { private void assertChoiceMetadata(Generation generation) { - ChoiceMetadata choiceMetadata = generation.getChoiceMetadata(); + ChatGenerationMetadata chatGenerationMetadata = generation.getResultMetadata(); - assertThat(choiceMetadata).isNotNull(); - assertThat(choiceMetadata.getFinishReason()).isEqualTo("stop"); - assertContentFilterResults(choiceMetadata.getContentFilterMetadata()); + assertThat(chatGenerationMetadata).isNotNull(); + assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop"); + assertContentFilterResults(chatGenerationMetadata.getContentFilterMetadata()); } private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt, diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 88427d9e025..9fa68242a01 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -17,7 +17,7 @@ package org.springframework.ai.bedrock; import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; /** diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java index 937819a022a..d01ecffddc5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java @@ -19,8 +19,8 @@ import java.util.List; import java.util.stream.Collectors; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; /** * Converts a list of messages to a prompt for bedrock models. diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java index 6c7dd7f39c1..a8b67702648 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; @@ -28,12 +29,11 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; /** * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat - * model. + * generative. * * @author Christian Tzolov * @since 0.8.0 @@ -89,8 +89,8 @@ public BedrockAnthropicChatClient withAnthropicVersion(String anthropicVersion) } @Override - public ChatResponse generate(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + public ChatResponse call(Prompt prompt) { + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); AnthropicChatRequest request = AnthropicChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -109,7 +109,7 @@ public ChatResponse generate(Prompt prompt) { @Override public Flux generateStream(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); AnthropicChatRequest request = AnthropicChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -126,8 +126,8 @@ public Flux generateStream(Prompt prompt) { String stopReason = response.stopReason() != null ? response.stopReason() : null; var generation = new Generation(response.completion()); if (response.amazonBedrockInvocationMetrics() != null) { - generation = generation - .withChoiceMetadata(ChoiceMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); } return new ChatResponse(List.of(generation)); }); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java index 50012480737..a5c1e3e188c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java @@ -19,6 +19,7 @@ import java.util.List; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.BedrockUsage; @@ -32,9 +33,8 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * @author Christian Tzolov @@ -112,7 +112,7 @@ public BedrockCohereChatClient withTruncate(Truncate truncate) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); List generations = response.generations().stream().map(g -> { return new Generation(g.text()); @@ -127,15 +127,15 @@ public Flux generateStream(Prompt prompt) { if (g.isFinished()) { String finishReason = g.finishReason().name(); Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); - return new ChatResponse( - List.of(new Generation("").withChoiceMetadata(ChoiceMetadata.from(finishReason, usage)))); + return new ChatResponse(List + .of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); } return new ChatResponse(List.of(new Generation(g.text()))); }); } private CohereChatRequest createRequest(Prompt prompt, boolean stream) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); return CohereChatRequest.builder(promptValue) .withTemperature(this.temperature) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java index e4fc8d69dbd..38e157394af 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java @@ -28,13 +28,13 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat - * model. + * generative. * * @author Christian Tzolov * @since 0.8.0 @@ -69,8 +69,8 @@ public BedrockLlama2ChatClient withMaxGenLen(Integer maxGenLen) { } @Override - public ChatResponse generate(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + public ChatResponse call(Prompt prompt) { + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); var request = Llama2ChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -80,14 +80,14 @@ public ChatResponse generate(Prompt prompt) { Llama2ChatResponse response = this.chatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.generation()) - .withChoiceMetadata(ChoiceMetadata.from(response.stopReason().name(), extractUsage(response))))); + return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( + ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); } @Override public Flux generateStream(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); var request = Llama2ChatRequest.builder(promptValue) .withTemperature(this.temperature) @@ -100,7 +100,7 @@ public Flux generateStream(Prompt prompt) { return fluxResponse.map(response -> { String stopReason = response.stopReason() != null ? response.stopReason().name() : null; return new ChatResponse(List.of(new Generation(response.generation()) - .withChoiceMetadata(ChoiceMetadata.from(stopReason, extractUsage(response))))); + .withGenerationMetadata(ChatGenerationMetadata.from(stopReason, extractUsage(response))))); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java index 14a95d04aba..b2279ce836b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java @@ -19,6 +19,7 @@ import java.util.List; import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; @@ -29,9 +30,8 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.Prompt; /** * @author Christian Tzolov @@ -74,7 +74,7 @@ public BedrockTitanChatClient withStopSequences(List stopSequences) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); List generations = response.results().stream().map(result -> { return new Generation(result.outputText()); @@ -91,12 +91,13 @@ public Flux generateStream(Prompt prompt) { if (chunk.amazonBedrockInvocationMetrics() != null) { String completionReason = chunk.completionReason().name(); - generation = generation - .withChoiceMetadata(ChoiceMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); } else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { String completionReason = chunk.completionReason().name(); - generation = generation.withChoiceMetadata(ChoiceMetadata.from(completionReason, extractUsage(chunk))); + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from(completionReason, extractUsage(chunk))); } return new ChatResponse(List.of(generation)); @@ -104,7 +105,7 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount( } private TitanChatRequest createRequest(Prompt prompt, boolean stream) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getMessages()); + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); return TitanChatRequest.builder(promptValue) .withTemperature(this.temperature) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java index 14ecad3f9e4..fd342892766 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -17,11 +18,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -52,9 +53,9 @@ void roleTest() { Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); + ChatResponse response = client.call(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -88,9 +89,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -112,9 +113,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -137,9 +138,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java index bb4ffc32fe5..9fc97ee2c6f 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -18,11 +19,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -53,8 +54,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Test @@ -70,9 +71,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -89,9 +90,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -112,9 +113,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -137,9 +138,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java index 4cbbc8d072e..b1ec0458e17 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -19,11 +20,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -54,9 +55,9 @@ void roleTest() { Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); + ChatResponse response = client.call(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -73,9 +74,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -91,9 +92,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -116,9 +117,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -142,9 +143,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java index 2f96bb84f23..1c07073cc0e 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -19,11 +20,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -54,8 +55,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -72,9 +73,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -93,9 +94,9 @@ void mapOutputParser() { Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -117,9 +118,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -143,9 +144,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java index 8b32bb80c19..a279f79372b 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatClient.java @@ -32,7 +32,7 @@ import org.springframework.ai.huggingface.model.GenerateParameters; import org.springframework.ai.huggingface.model.GenerateRequest; import org.springframework.ai.huggingface.model.GenerateResponse; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; /** * An implementation of {@link ChatClient} that interfaces with HuggingFace Inference @@ -86,7 +86,7 @@ public HuggingfaceChatClient(final String apiToken, String basePath) { * @return ChatResponse containing the generated text and other related details. */ @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { GenerateRequest generateRequest = new GenerateRequest(); generateRequest.setInputs(prompt.getContents()); GenerateParameters generateParameters = new GenerateParameters(); diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 7c07a61e340..05f814a5fc8 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -22,7 +22,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.huggingface.HuggingfaceChatClient; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; @@ -47,8 +47,8 @@ void helloWorldCompletion() { [/INST] """; Prompt prompt = new Prompt(mistral7bInstruct); - ChatResponse chatResponse = huggingfaceChatClient.generate(prompt); - assertThat(chatResponse.getGeneration().getContent()).isNotEmpty(); + ChatResponse chatResponse = huggingfaceChatClient.call(prompt); + assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); String expectedResponse = """ ```json { @@ -57,9 +57,9 @@ void helloWorldCompletion() { "address": "#1 Samuel St." } ```"""; - assertThat(chatResponse.getGeneration().getContent()).isEqualTo(expectedResponse); - assertThat(chatResponse.getGeneration().getProperties()).containsKey("generated_tokens"); - assertThat(chatResponse.getGeneration().getProperties()).containsEntry("generated_tokens", 39); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo(expectedResponse); + assertThat(chatResponse.getResult().getOutput().getProperties()).containsKey("generated_tokens"); + assertThat(chatResponse.getResult().getOutput().getProperties()).containsEntry("generated_tokens", 39); } diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 0da9e9a5353..ad143f9afb0 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -21,6 +21,12 @@ + + + org.springframework.boot + spring-boot + + org.springframework.ai spring-ai-core diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java index 3c72a5d42a1..e1835c67ac1 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java @@ -16,25 +16,30 @@ package org.springframework.ai.ollama; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; /** * {@link ChatClient} implementation for {@literal Ollma}. @@ -58,6 +63,8 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient { private Map clientOptions; + private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + public OllamaChatClient(OllamaApi chatApi) { this.chatApi = chatApi; } @@ -78,12 +85,13 @@ public OllamaChatClient withOptions(OllamaOptions options) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { OllamaApi.ChatResponse response = this.chatApi.chat(request(prompt, this.model, false)); var generator = new Generation(response.message().content()); if (response.promptEvalCount() != null && response.evalCount() != null) { - generator = generator.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(response))); + generator = generator + .withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(response))); } return new ChatResponse(List.of(generator)); } @@ -97,7 +105,8 @@ public Flux generateStream(Prompt prompt) { Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content()) : new Generation(""); if (Boolean.TRUE.equals(chunk.done())) { - generation = generation.withChoiceMetadata(ChoiceMetadata.from("unknown", extractUsage(chunk))); + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(chunk))); } return new ChatResponse(List.of(generation)); }); @@ -120,20 +129,57 @@ public Long getGenerationTokens() { private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean stream) { - List ollamaMessages = prompt.getMessages() + List ollamaMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.ASSISTANT) .map(m -> OllamaApi.Message.builder(toRole(m)).withContent(m.getContent()).build()) .toList(); + // runtime options + Map promptOptions = objectToMap(prompt.getOptions()); + Map clientOptionsToUse = merge(promptOptions, this.clientOptions, HashMap.class); + return ChatRequest.builder(model) .withStream(stream) .withMessages(ollamaMessages) - .withOptions(this.clientOptions) + .withOptions(clientOptionsToUse) .build(); } + public static Map objectToMap(Object source) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static T mapToClass(Map source, Class clazz) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, clazz); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static T merge(Object source, Object target, Class clazz) { + Map sourceMap = objectToMap(source); + Map targetMap = objectToMap(target); + + targetMap.putAll(sourceMap.entrySet() + .stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + + return mapToClass(targetMap, clazz); + } + private OllamaApi.Message.Role toRole(Message message) { switch (message.getMessageType()) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 172b08fe52b..5bc3f6c6f3c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.ChatOptions; /** * Helper class for creating strongly-typed Ollama options. @@ -38,7 +39,7 @@ * Types */ @JsonInclude(Include.NON_NULL) -public class OllamaOptions { +public class OllamaOptions implements ChatOptions { // @formatter:off /** diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java index 88a184e6f7e..84617fea0bc 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java @@ -11,6 +11,8 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.ChatOptionsBuilder; +import org.springframework.ai.chat.messages.AssistantMessage; import org.testcontainers.containers.GenericContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -22,11 +24,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -51,7 +53,7 @@ class OllamaChatClientIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL + " ' generative ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL); logger.info(MODEL + " pulling competed!"); @@ -72,9 +74,16 @@ void roleTest() { UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy."); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + // portable/generic options + var chatOptionsBuilder = ChatOptionsBuilder.builder(); + + // ollama specific options + var ollamaOptions = new OllamaOptions().withLowVRAM(true); + + Prompt prompt = new Prompt(List.of(userMessage, systemMessage), + chatOptionsBuilder.withTemperature(0.7f).build()); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @Disabled("TODO: Fix the parser instructions to return the correct format") @@ -91,9 +100,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -112,9 +121,9 @@ void mapOutputParser() { Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -136,9 +145,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -162,9 +171,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java index 393366092b1..8f066d05b95 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java @@ -36,7 +36,7 @@ class OllamaEmbeddingClientIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ..."); + logger.info("Start pulling the 'orca-mini' generative (3GB) ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); logger.info("orca-mini pulling competed!"); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index ae5ed2ea1c0..da34fa83997 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -57,7 +57,7 @@ public class OllamaApiIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the 'orca-mini' model (3GB) ... would take several minutes ..."); + logger.info("Start pulling the 'orca-mini' generative (3GB) ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", "orca-mini"); logger.info("orca-mini pulling competed!"); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java similarity index 97% rename from models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java rename to models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index 60c4f098f98..6320854ec35 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -25,7 +25,7 @@ /** * @author Christian Tzolov */ -public class OllamaOptionsTests { +public class OllamaModelOptionsTests { @Test public void testOptions() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 9ea1617411e..64ebc1c45d0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -28,17 +28,17 @@ import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiClientErrorException; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; -import org.springframework.ai.openai.metadata.OpenAiGenerationMetadata; +import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.Message; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -94,10 +94,10 @@ public void setTemperature(Double temperature) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getMessages(); + List messages = prompt.getInstructions(); List chatCompletionMessages = messages.stream() .map(m -> new ChatCompletionMessage(m.getContent(), @@ -118,18 +118,18 @@ public ChatResponse generate(Prompt prompt) { List generations = chatCompletion.choices().stream().map(choice -> { return new Generation(choice.message().content(), Map.of("role", choice.message().role().name())) - .withChoiceMetadata(ChoiceMetadata.from(choice.finishReason().name(), null)); + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); }).toList(); return new ChatResponse(generations, - OpenAiGenerationMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits)); + OpenAiChatResponseMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits)); }); } @Override public Flux generateStream(Prompt prompt) { return this.retryTemplate.execute(ctx -> { - List messages = prompt.getMessages(); + List messages = prompt.getInstructions(); List chatCompletionMessages = messages.stream() .map(m -> new ChatCompletionMessage(m.getContent(), @@ -153,7 +153,7 @@ public Flux generateStream(Prompt prompt) { var generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId))); if (choice.finishReason() != null) { generation = generation - .withChoiceMetadata(ChoiceMetadata.from(choice.finishReason().name(), null)); + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); } return generation; }).toList(); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java new file mode 100644 index 00000000000..db829b18369 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -0,0 +1,126 @@ +package org.springframework.ai.openai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.image.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.openai.api.*; +import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; +import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.util.List; + +public class OpenAiImageClient implements ImageClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private OpenAiImageOptions options; + + private final OpenAiImageApi openAiImageApi; + + public final RetryTemplate retryTemplate = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(OpenAiApi.OpenAiApiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .build(); + + public OpenAiImageClient(OpenAiImageApi openAiImageApi) { + Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null"); + this.openAiImageApi = openAiImageApi; + } + + public OpenAiImageOptions getOptions() { + return options; + } + + @Override + public ImageResponse call(ImagePrompt imagePrompt) { + return this.retryTemplate.execute(ctx -> { + ImageOptions runtimeOptions = imagePrompt.getOptions(); + OpenAiImageOptions imageOptionsToUse = updateImageOptions(imagePrompt.getOptions()); + + // Merge the runtime options passed via the prompt with the + // StabilityAiImageClient + // options configured via Autoconfiguration. + // Runtime options overwrite StabilityAiImageClient options + OpenAiImageOptions optionsToUse = ModelOptionsUtils.merge(runtimeOptions, this.options, + OpenAiImageOptionsImpl.class); + + // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions + // data + // types to the data types used in OpenAiImageApi + String instructions = imagePrompt.getInstructions().get(0).getText(); + OpenAiImageApi.OpenAiImageRequest openAiImageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions, + imageOptionsToUse.getModel(), imageOptionsToUse.getN(), imageOptionsToUse.getQuality(), + imageOptionsToUse.getWidth() + "x" + imageOptionsToUse.getHeight(), + imageOptionsToUse.getResponseFormat(), imageOptionsToUse.getStyle(), imageOptionsToUse.getUser()); + + // Make the request + ResponseEntity imageResponseEntity = this.openAiImageApi + .createImage(openAiImageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(imageResponseEntity, openAiImageRequest); + + }); + } + + private ImageResponse convertResponse(ResponseEntity imageResponseEntity, + OpenAiImageApi.OpenAiImageRequest openAiImageRequest) { + OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody(); + if (imageApiResponse == null) { + logger.warn("No image response returned for request: {}", openAiImageRequest); + return new ImageResponse(List.of()); + } + + List imageGenerationList = imageApiResponse.data().stream().map(entry -> { + return new ImageGeneration(new Image(entry.url(), entry.b64Json()), + new OpenAiImageGenerationMetadata(entry.revisedPrompt())); + }).toList(); + + ImageResponseMetadata openAiImageResponseMetadata = OpenAiImageResponseMetadata.from(imageApiResponse); + return new ImageResponse(imageGenerationList, openAiImageResponseMetadata); + } + + private OpenAiImageOptions updateImageOptions(ImageOptions runtimeImageOptions) { + OpenAiImageOptionsBuilder openAiImageOptionsBuilder = OpenAiImageOptionsBuilder.builder(); + if (runtimeImageOptions != null) { + // Handle portable image options + if (runtimeImageOptions.getN() != null) { + openAiImageOptionsBuilder.withN(runtimeImageOptions.getN()); + } + if (runtimeImageOptions.getModel() != null) { + openAiImageOptionsBuilder.withModel(runtimeImageOptions.getModel()); + } + if (runtimeImageOptions.getResponseFormat() != null) { + openAiImageOptionsBuilder.withResponseFormat(runtimeImageOptions.getResponseFormat()); + } + if (runtimeImageOptions.getWidth() != null) { + openAiImageOptionsBuilder.withWidth(runtimeImageOptions.getWidth()); + } + if (runtimeImageOptions.getHeight() != null) { + openAiImageOptionsBuilder.withHeight(runtimeImageOptions.getHeight()); + } + // Handle OpenAI specific image options + if (runtimeImageOptions instanceof OpenAiImageOptions) { + OpenAiImageOptions runtimeOpenAiImageOptions = (OpenAiImageOptions) runtimeImageOptions; + if (runtimeOpenAiImageOptions.getQuality() != null) { + openAiImageOptionsBuilder.withQuality(runtimeOpenAiImageOptions.getQuality()); + } + if (runtimeOpenAiImageOptions.getStyle() != null) { + openAiImageOptionsBuilder.withStyle(runtimeOpenAiImageOptions.getStyle()); + } + if (runtimeOpenAiImageOptions.getUser() != null) { + openAiImageOptionsBuilder.withUser(runtimeOpenAiImageOptions.getUser()); + } + } + } + OpenAiImageOptions updatedOpenAiImageOptions = openAiImageOptionsBuilder.build(); + return updatedOpenAiImageOptions; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index a31c028872a..55cda04321a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -69,7 +69,7 @@ public OpenAiApi(String openAiToken) { } /** - * Create an new chat completion api. + * Create a new chat completion api. * * @param baseUrl api base URL. * @param openAiToken OpenAI apiKey. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java new file mode 100644 index 00000000000..cf02c33a8ef --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -0,0 +1,250 @@ +package org.springframework.ai.openai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.image.ImageResponse; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.function.Consumer; + +public class OpenAiImageApi { + + private static final String DEFAULT_BASE_URL = "https://api.openai.com"; + + public static final String DEFAULT_IMAGE_MODEL = "dall-e-2"; + + // Assuming RestClient and WebClient are properly defined somewhere + private final RestClient restClient; + + private final WebClient webClient; + + private final ObjectMapper objectMapper; + + /** + * Create a new OpenAI Image api with base URL set to https://api.openai.com + * @param openAiToken OpenAI apiKey. + */ + public OpenAiImageApi(String openAiToken) { + this(DEFAULT_BASE_URL, openAiToken, RestClient.builder()); + } + + public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { + + this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + Consumer jsonContentHeaders = headers -> { + headers.setBearerAuth(openAiToken); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + var responseErrorHandler = new ResponseErrorHandler() { + + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + throw new OpenAiApi.OpenAiApiException(String.format("%s - %s", response.getStatusCode().value(), + new ObjectMapper().readValue(response.getBody(), OpenAiApi.ResponseError.class))); + } + } + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); + } + + public static class OpenAiImageApiException extends RuntimeException { + + public OpenAiImageApiException(String message) { + super(message); + } + + public OpenAiImageApiException(String message, Throwable cause) { + super(message, cause); + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class ImageResponseError { + + private final Error error; + + public ImageResponseError(@JsonProperty("error") Error error) { + this.error = error; + } + + public Error getError() { + return error; + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class OpenAiImageRequest { + + @JsonProperty("prompt") + private String prompt; + + @JsonProperty("model") + private String model = DEFAULT_IMAGE_MODEL; + + @JsonProperty("n") + private Integer n; + + @JsonProperty("quality") + private String quality; + + @JsonProperty("response_format") + private String responseFormat; + + @JsonProperty("size") + private String size; + + @JsonProperty("style") + private String style; + + @JsonProperty("user") + private String user; + + public OpenAiImageRequest() { + } + + public OpenAiImageRequest(String prompt, String model, Integer n, String quality, String size, + String responseFormat, String style, String user) { + this.prompt = prompt; + this.model = model; + this.n = n; + this.quality = quality; + this.size = size; + this.responseFormat = responseFormat; + this.style = style; + this.user = user; + } + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + public String getQuality() { + return quality; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + public String getSize() { + return size; + } + + public void setSize(String size) { + this.size = size; + } + + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageRequest that)) + return false; + return Objects.equals(prompt, that.prompt) && Objects.equals(model, that.model) && Objects.equals(n, that.n) + && Objects.equals(quality, that.quality) && Objects.equals(size, that.size) + && Objects.equals(responseFormat, that.responseFormat) && Objects.equals(style, that.style) + && Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hash(prompt, model, n, quality, size, responseFormat, style, user); + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record OpenAiImageResponse(@JsonProperty("created") Long created, @JsonProperty("data") List data) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data(@JsonProperty("url") String url, @JsonProperty("b64_json") String b64Json, + @JsonProperty("revised_prompt") String revisedPrompt) { + + } + + public ResponseEntity createImage(OpenAiImageRequest openAiImageRequest) { + Assert.notNull(openAiImageRequest, "Image request cannot be null."); + Assert.hasLength(openAiImageRequest.getPrompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("v1/images/generations") + .body(openAiImageRequest) + .retrieve() + .toEntity(OpenAiImageResponse.class); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java new file mode 100644 index 00000000000..327f5dcbb56 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptions.java @@ -0,0 +1,13 @@ +package org.springframework.ai.openai.api; + +import org.springframework.ai.image.ImageOptions; + +public interface OpenAiImageOptions extends ImageOptions { + + String getQuality(); + + String getStyle(); + + String getUser(); + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java new file mode 100644 index 00000000000..b896a2f9e5e --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsBuilder.java @@ -0,0 +1,59 @@ +package org.springframework.ai.openai.api; + +public class OpenAiImageOptionsBuilder { + + private final OpenAiImageOptionsImpl options; + + private OpenAiImageOptionsBuilder() { + this.options = new OpenAiImageOptionsImpl(); + } + + public static OpenAiImageOptionsBuilder builder() { + return new OpenAiImageOptionsBuilder(); + } + + public OpenAiImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public OpenAiImageOptionsBuilder withModel(String model) { + options.setModel(model); + return this; + } + + public OpenAiImageOptionsBuilder withQuality(String quality) { + options.setQuality(quality); + return this; + } + + public OpenAiImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public OpenAiImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public OpenAiImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public OpenAiImageOptionsBuilder withStyle(String style) { + options.setStyle(style); + return this; + } + + public OpenAiImageOptionsBuilder withUser(String user) { + options.setUser(user); + return this; + } + + public OpenAiImageOptions build() { + return options; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java new file mode 100644 index 00000000000..719c0918cb7 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageOptionsImpl.java @@ -0,0 +1,93 @@ +package org.springframework.ai.openai.api; + +public class OpenAiImageOptionsImpl implements OpenAiImageOptions { + + private Integer n; + + private String model; + + private String quality; + + private String responseFormat; + + private Integer width; + + private Integer height; + + private String style; + + private String user; + + @Override + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public String getQuality() { + return quality; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Integer getWidth() { + return width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + @Override + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + @Override + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java similarity index 66% rename from models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java rename to models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java index 23c5474b5af..ad797df518f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiChatResponseMetadata.java @@ -16,31 +16,31 @@ package org.springframework.ai.openai.metadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * {@link GenerationMetadata} implementation for {@literal OpenAI}. + * {@link ChatResponseMetadata} implementation for {@literal OpenAI}. * * @author John Blum - * @see org.springframework.ai.metadata.GenerationMetadata - * @see org.springframework.ai.metadata.RateLimit - * @see org.springframework.ai.metadata.Usage + * @see ChatResponseMetadata + * @see RateLimit + * @see Usage * @since 0.7.0 */ -public class OpenAiGenerationMetadata implements GenerationMetadata { +public class OpenAiChatResponseMetadata implements ChatResponseMetadata { protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; - public static OpenAiGenerationMetadata from(OpenAiApi.ChatCompletion result) { + public static OpenAiChatResponseMetadata from(OpenAiApi.ChatCompletion result) { Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); OpenAiUsage usage = OpenAiUsage.from(result.usage()); - OpenAiGenerationMetadata generationMetadata = new OpenAiGenerationMetadata(result.id(), usage); - return generationMetadata; + OpenAiChatResponseMetadata chatResponseMetadata = new OpenAiChatResponseMetadata(result.id(), usage); + return chatResponseMetadata; } private final String id; @@ -50,11 +50,11 @@ public static OpenAiGenerationMetadata from(OpenAiApi.ChatCompletion result) { private final Usage usage; - protected OpenAiGenerationMetadata(String id, OpenAiUsage usage) { + protected OpenAiChatResponseMetadata(String id, OpenAiUsage usage) { this(id, usage, null); } - protected OpenAiGenerationMetadata(String id, OpenAiUsage usage, @Nullable OpenAiRateLimit rateLimit) { + protected OpenAiChatResponseMetadata(String id, OpenAiUsage usage, @Nullable OpenAiRateLimit rateLimit) { this.id = id; this.usage = usage; this.rateLimit = rateLimit; @@ -77,7 +77,7 @@ public Usage getUsage() { return usage != null ? usage : Usage.NULL; } - public OpenAiGenerationMetadata withRateLimit(RateLimit rateLimit) { + public OpenAiChatResponseMetadata withRateLimit(RateLimit rateLimit) { this.rateLimit = rateLimit; return this; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java new file mode 100644 index 00000000000..6929d85c3d0 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java @@ -0,0 +1,38 @@ +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.image.ImageGenerationMetadata; + +import java.util.Objects; + +public class OpenAiImageGenerationMetadata implements ImageGenerationMetadata { + + private String revisedPrompt; + + public OpenAiImageGenerationMetadata(String revisedPrompt) { + this.revisedPrompt = revisedPrompt; + } + + public String getRevisedPrompt() { + return revisedPrompt; + } + + @Override + public String toString() { + return "OpenAiImageGenerationMetadata{" + "revisedPrompt='" + revisedPrompt + '\'' + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageGenerationMetadata that)) + return false; + return Objects.equals(revisedPrompt, that.revisedPrompt); + } + + @Override + public int hashCode() { + return Objects.hash(revisedPrompt); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java new file mode 100644 index 00000000000..996c50e2f5b --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java @@ -0,0 +1,46 @@ +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.util.Assert; + +import java.util.Objects; + +public class OpenAiImageResponseMetadata implements ImageResponseMetadata { + + private final Long created; + + public static OpenAiImageResponseMetadata from(OpenAiImageApi.OpenAiImageResponse openAiImageResponse) { + Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null"); + return new OpenAiImageResponseMetadata(openAiImageResponse.created()); + } + + protected OpenAiImageResponseMetadata(Long created) { + this.created = created; + } + + @Override + public Long created() { + return this.created; + } + + @Override + public String toString() { + return "OpenAiImageResponseMetadata{" + "created=" + created + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof OpenAiImageResponseMetadata that)) + return false; + return Objects.equals(created, that.created); + } + + @Override + public int hashCode() { + return Objects.hash(created); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java index 4a0697cb397..c740c9d33fe 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java @@ -18,7 +18,7 @@ import java.time.Duration; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.RateLimit; /** * {@link RateLimit} implementation for {@literal OpenAI}. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 1b08d449dda..09361b2d6f9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -16,7 +16,7 @@ package org.springframework.ai.openai.metadata; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.util.Assert; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index be6eb4b4bf4..21c28652202 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -26,7 +26,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.metadata.OpenAiRateLimit; import org.springframework.http.ResponseEntity; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 894e4e08002..b14b0ba9391 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -2,6 +2,7 @@ import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -11,9 +12,12 @@ public class OpenAiTestConfiguration { @Bean public OpenAiApi openAiApi() { - String apiKey = getApiKey(); - OpenAiApi openAiService = new OpenAiApi(apiKey); - return openAiService; + return new OpenAiApi(getApiKey()); + } + + @Bean + public OpenAiImageApi openAiImageApi() { + return new OpenAiImageApi(getApiKey()); } private String getApiKey() { @@ -32,6 +36,13 @@ public OpenAiChatClient openAiChatClient(OpenAiApi api) { return openAiChatClient; } + @Bean + public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) { + OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi); + // openAiImageClient.setModel("foobar"); + return openAiImageClient; + } + @Bean public EmbeddingClient openAiEmbeddingClient(OpenAiApi api) { return new OpenAiEmbeddingClient(api); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java index 92fc5bd357e..9bdb44d8ec2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/acme/AcmeIT.java @@ -15,10 +15,10 @@ import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.testutils.AbstractIT; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.SimpleVectorStore; @@ -90,10 +90,10 @@ void acmeChain() { // Create the prompt ad-hoc for now, need to put in system message and user // message via ChatPromptTemplate or some other message building mechanic; - logger.info("Asking AI model to reply to question."); + logger.info("Asking AI generative to reply to question."); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); logger.info("AI responded."); - ChatResponse response = chatClient.generate(prompt); + ChatResponse response = chatClient.call(prompt); evaluateQuestionAndAnswer(userQuery, response, true); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java index 9697a8b5f23..aa35ba9e04b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java @@ -10,16 +10,17 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.convert.support.DefaultConversionService; @@ -41,9 +42,9 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = openAiChatClient.generate(prompt); - assertThat(response.getGenerations()).hasSize(1); - assertThat(response.getGenerations().get(0).getContent()).contains("Blackbeard"); + ChatResponse response = openAiChatClient.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); } @@ -60,9 +61,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.openAiChatClient.generate(prompt).getGeneration(); + Generation generation = this.openAiChatClient.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -79,9 +80,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -98,9 +99,9 @@ void beanOutputParser() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - ActorsFilms actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent()); } record ActorsFilmsRecord(String actor, List movies) { @@ -118,9 +119,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiChatClient.generate(prompt).getGeneration(); + Generation generation = openAiChatClient.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); System.out.println(actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); @@ -143,9 +144,10 @@ void beanStreamOutputParserRecords() { .collectList() .block() .stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java similarity index 82% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java index cd091b92261..8c2b494d9c1 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithGenerationMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java @@ -22,15 +22,11 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.ai.metadata.RateLimit; -import org.springframework.ai.metadata.Usage; +import org.springframework.ai.chat.metadata.*; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -52,8 +48,8 @@ * @author Christian Tzolov * @since 0.7.0 */ -@RestClientTest(OpenAiChatClientWithGenerationMetadataTests.Config.class) -public class OpenAiChatClientWithGenerationMetadataTests { +@RestClientTest(OpenAiChatClientWithChatResponseMetadataTests.Config.class) +public class OpenAiChatClientWithChatResponseMetadataTests { private static String TEST_API_KEY = "sk-1234567890"; @@ -75,22 +71,22 @@ void aiResponseContainsAiMetadata() { Prompt prompt = new Prompt("Reach for the sky."); - ChatResponse response = this.openAiChatClient.generate(prompt); + ChatResponse response = this.openAiChatClient.call(prompt); assertThat(response).isNotNull(); - GenerationMetadata generationMetadata = response.getGenerationMetadata(); + ChatResponseMetadata chatResponseMetadata = response.getMetadata(); - assertThat(generationMetadata).isNotNull(); + assertThat(chatResponseMetadata).isNotNull(); - Usage usage = generationMetadata.getUsage(); + Usage usage = chatResponseMetadata.getUsage(); assertThat(usage).isNotNull(); assertThat(usage.getPromptTokens()).isEqualTo(9L); assertThat(usage.getGenerationTokens()).isEqualTo(12L); assertThat(usage.getTotalTokens()).isEqualTo(21L); - RateLimit rateLimit = generationMetadata.getRateLimit(); + RateLimit rateLimit = chatResponseMetadata.getRateLimit(); Duration expectedRequestsReset = Duration.ofDays(2L) .plus(Duration.ofHours(16L)) @@ -109,16 +105,16 @@ void aiResponseContainsAiMetadata() { assertThat(rateLimit.getTokensRemaining()).isEqualTo(112_358L); assertThat(rateLimit.getTokensReset()).isEqualTo(expectedTokensReset); - PromptMetadata promptMetadata = response.getPromptMetadata(); + PromptMetadata promptMetadata = response.getMetadata().getPromptMetadata(); assertThat(promptMetadata).isNotNull(); assertThat(promptMetadata).isEmpty(); - response.getGenerations().forEach(generation -> { - ChoiceMetadata choiceMetadata = generation.getChoiceMetadata(); - assertThat(choiceMetadata).isNotNull(); - assertThat(choiceMetadata.getFinishReason()).isEqualTo("STOP"); - assertThat(choiceMetadata.getContentFilterMetadata()).isNull(); + response.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getResultMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP"); + assertThat(chatGenerationMetadata.getContentFilterMetadata()).isNull(); }); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index b2515a6a5c5..4c5d6f9af0c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -38,7 +38,7 @@ void simpleEmbedding() { EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); assertThat(embeddingResponse.getData()).hasSize(1); assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002-v2"); + assertThat(embeddingResponse.getMetadata()).containsEntry("generative", "text-embedding-ada-002-v2"); assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2); assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java new file mode 100644 index 00000000000..389f017be99 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java @@ -0,0 +1,25 @@ +package org.springframework.ai.openai.image; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.boot.test.context.SpringBootTest; + +@SpringBootTest(classes = OpenAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +public class OpenAiImageClientIT extends AbstractIT { + + @Test + void imageAsUrlTest() { + ImagePrompt imagePrompt = new ImagePrompt("Create an image of a mini golden doodle dog."); + + ImageResponse imageResponse = openaiImageClient.call(imagePrompt); + + System.out.println(imageResponse); + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java new file mode 100644 index 00000000000..fa6aa0a3762 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.image; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestClient; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * @author John Blum + * @author Christian Tzolov + * @since 0.7.0 + */ +@RestClientTest(OpenAiImageClientWithImageResponseMetadataTests.Config.class) +public class OpenAiImageClientWithImageResponseMetadataTests { + + private static String TEST_API_KEY = "sk-1234567890"; + + @Autowired + private OpenAiImageClient openAiImageClient; + + @Autowired + private MockRestServiceServer server; + + @AfterEach + void resetMockServer() { + server.reset(); + } + + @Test + void aiResponseContainsImageRespoinseMetadata() { + + prepareMock(); + + ImagePrompt prompt = new ImagePrompt("Create an image of a mini golden doodle dog."); + + ImageResponse response = this.openAiImageClient.call(prompt); + + assertThat(response).isNotNull(); + List imageGenerations = response.getResults(); + assertThat(imageGenerations).isNotNull(); + assertThat(imageGenerations).hasSize(2); + + ImageResponseMetadata imageResponseMetadata = response.getMetadata(); + + assertThat(imageResponseMetadata).isNotNull(); + + Long created = imageResponseMetadata.created(); + + assertThat(created).isNotNull(); + assertThat(created).isEqualTo(1589478378); + + ImageResponseMetadata responseMetadata = response.getMetadata(); + + assertThat(responseMetadata).isNotNull(); + + } + + private void prepareMock() { + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER.getName(), "999"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER.getName(), "2d16h15m29s"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER.getName(), "725000"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); + + server.expect(requestTo("v1/images/generations")) + .andExpect(method(HttpMethod.POST)) + .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) + .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); + + } + + private String getJson() { + return """ + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + }, + { + "url": "https://upload.wikimedia.org/wikipedia/commons/8/85/Goldendoodle_puppy_Marty.jpg" + } + ] + } + """; + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiImageApi imageGenerationApi(RestClient.Builder builder) { + return new OpenAiImageApi("", TEST_API_KEY, builder); + } + + @Bean + public OpenAiImageClient openAiImageClient(OpenAiImageApi openAiImageApi) { + return new OpenAiImageClient(openAiImageApi); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index a41ca889c16..2e878f802b7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -9,10 +9,11 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.image.ImageClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -27,6 +28,9 @@ public abstract class AbstractIT { @Autowired protected ChatClient openAiChatClient; + @Autowired + protected ImageClient openaiImageClient; + @Autowired protected StreamingChatClient openStreamingChatClient; @@ -44,7 +48,7 @@ public abstract class AbstractIT { protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getContent(); + String answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -58,12 +62,12 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response, } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatClient.generate(prompt).getGeneration().getContent(); + String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatClient.generate(prompt).getGeneration().getContent(); + String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index a4e79b3ec40..120b32a7be2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -65,7 +65,7 @@ public class MetadataTransformerIT { Document document2 = new Document( "The Spring Framework is divided into modules. Applications can choose which modules" - + " they need. At the heart are the modules of the core container, including a configuration model and a " + + " they need. At the heart are the modules of the core container, including a configuration generative and a " + "dependency injection mechanism. Beyond that, the Spring Framework provides foundational support " + " for different application architectures, including messaging, transactional data and persistence, " + "and web. It also includes the Servlet-based Spring MVC web framework and, in parallel, the Spring " diff --git a/models/spring-ai-stabilityai/pom.xml b/models/spring-ai-stabilityai/pom.xml new file mode 100644 index 00000000000..d189362de22 --- /dev/null +++ b/models/spring-ai-stabilityai/pom.xml @@ -0,0 +1,60 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 0.8.0-SNAPSHOT + ../../pom.xml + + spring-ai-stability-ai + jar + Spring AI Stability AI + Stability AI support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework + spring-web + ${spring-framework.version} + + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java new file mode 100644 index 00000000000..edd67c532a1 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java @@ -0,0 +1,158 @@ +package org.springframework.ai.stabilityai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.image.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptionsBuilder; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptionsImpl; +import org.springframework.util.Assert; + +import java.util.List; +import java.util.stream.Collectors; + +public class StabilityAiImageClient implements ImageClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private StabilityAiImageOptions options; + + private final StabilityAiApi stabilityAiApi; + + public StabilityAiImageClient(StabilityAiApi stabilityAiApi) { + this(stabilityAiApi, StabilityAiImageOptionsBuilder.builder().build()); + } + + public StabilityAiImageClient(StabilityAiApi stabilityAiApi, StabilityAiImageOptions options) { + Assert.notNull(stabilityAiApi, "StabilityAiApi must not be null"); + Assert.notNull(options, "StabilityAiImageOptions must not be null"); + this.stabilityAiApi = stabilityAiApi; + this.options = options; + } + + public StabilityAiImageOptions getOptions() { + return options; + } + + /** + * Calls the StabilityAiImageClient with the given StabilityAiImagePrompt and returns + * the ImageResponse. This overloaded call method lets you pass the full set of Prompt + * instructions that StabilityAI supports. + * @param imagePrompt the StabilityAiImagePrompt containing the prompt and image model + * options + * @return the ImageResponse generated by the StabilityAiImageClient + */ + public ImageResponse call(ImagePrompt imagePrompt) { + + ImageOptions runtimeOptions = imagePrompt.getOptions(); + + // Merge the runtime options passed via the prompt with the StabilityAiImageClient + // options configured via Autoconfiguration. + // Runtime options overwrite StabilityAiImageClient options + StabilityAiImageOptions optionsToUse = ModelOptionsUtils.merge(runtimeOptions, this.options, + StabilityAiImageOptionsImpl.class); + + // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data + // types to the data types used in StabilityAiApi + StabilityAiApi.GenerateImageRequest generateImageRequest = getGenerateImageRequest(imagePrompt, optionsToUse); + + // Make the request + StabilityAiApi.GenerateImageResponse generateImageResponse = this.stabilityAiApi + .generateImage(generateImageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(generateImageResponse); + + } + + private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, + StabilityAiImageOptions optionsToUse) { + StabilityAiApi.GenerateImageRequest.Builder builder = new StabilityAiApi.GenerateImageRequest.Builder(); + StabilityAiApi.GenerateImageRequest generateImageRequest = builder + .withTextPrompts(stabilityAiImagePrompt.getInstructions() + .stream() + .map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), + message.getWeight())) + .collect(Collectors.toList())) + .withHeight(optionsToUse.getHeight()) + .withWidth(optionsToUse.getWidth()) + .withCfgScale(optionsToUse.getCfgScale()) + .withClipGuidancePreset(optionsToUse.getClipGuidancePreset()) + .withSampler(optionsToUse.getSampler()) + .withSamples(optionsToUse.getSamples()) + .withSeed(optionsToUse.getSeed()) + .withSteps(optionsToUse.getSteps()) + .withStylePreset(optionsToUse.getStylePreset()) + .build(); + return generateImageRequest; + } + + private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) { + List imageGenerationList = generateImageResponse.artifacts().stream().map(entry -> { + return new ImageGeneration(new Image(null, entry.base64()), + new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed())); + }).toList(); + + return new ImageResponse(imageGenerationList, ImageResponseMetadata.NULL); + } + + private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) { + StabilityAiImageOptionsBuilder builder = StabilityAiImageOptionsBuilder.builder(); + if (runtimeOptions == null) { + return builder.build(); + } + if (runtimeOptions.getN() != null) { + builder.withN(runtimeOptions.getN()); + } + if (runtimeOptions.getModel() != null) { + builder.withModel(runtimeOptions.getModel()); + } + if (runtimeOptions.getResponseFormat() != null) { + builder.withResponseFormat(runtimeOptions.getResponseFormat()); + } + if (runtimeOptions.getWidth() != null) { + builder.withWidth(runtimeOptions.getWidth()); + } + if (runtimeOptions.getHeight() != null) { + builder.withHeight(runtimeOptions.getHeight()); + } + if (runtimeOptions instanceof StabilityAiImageOptions) { + StabilityAiImageOptions stabilityAiImageOptions = (StabilityAiImageOptions) runtimeOptions; + if (stabilityAiImageOptions.getCfgScale() != null) { + builder.withCfgScale(stabilityAiImageOptions.getCfgScale()); + } + if (stabilityAiImageOptions.getClipGuidancePreset() != null) { + builder.withClipGuidancePreset(stabilityAiImageOptions.getClipGuidancePreset()); + } + if (stabilityAiImageOptions.getSampler() != null) { + builder.withSampler(stabilityAiImageOptions.getSampler()); + } + if (stabilityAiImageOptions.getSeed() != null) { + builder.withSeed(stabilityAiImageOptions.getSeed()); + } + if (stabilityAiImageOptions.getSteps() != null) { + builder.withSteps(stabilityAiImageOptions.getSteps()); + } + if (stabilityAiImageOptions.getStylePreset() != null) { + builder.withStylePreset(stabilityAiImageOptions.getStylePreset()); + } + } + return builder.build(); + } + + private ImagePrompt createUpdatedPrompt(ImagePrompt prompt) { + ImageOptions runtimeImageModelOptions = prompt.getOptions(); + ImageOptionsBuilder imageOptionsBuilder = ImageOptionsBuilder.builder(); + + if (runtimeImageModelOptions != null) { + if (runtimeImageModelOptions.getModel() != null) { + imageOptionsBuilder.withModel(runtimeImageModelOptions.getModel()); + } + } + ImageOptions updatedImageModelOptions = imageOptionsBuilder.build(); + return new ImagePrompt(prompt.getInstructions(), updatedImageModelOptions); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java new file mode 100644 index 00000000000..8eeb4be9ae5 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java @@ -0,0 +1,45 @@ +package org.springframework.ai.stabilityai; + +import org.springframework.ai.image.ImageGenerationMetadata; + +import java.util.Objects; + +public class StabilityAiImageGenerationMetadata implements ImageGenerationMetadata { + + private String finishReason; + + private Long seed; + + public StabilityAiImageGenerationMetadata(String finishReason, Long seed) { + this.finishReason = finishReason; + this.seed = seed; + } + + public String getFinishReason() { + return finishReason; + } + + public Long getSeed() { + return seed; + } + + @Override + public String toString() { + return "StabilityAiImageGenerationMetadata{" + "finishReason='" + finishReason + '\'' + ", seed=" + seed + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof StabilityAiImageGenerationMetadata that)) + return false; + return Objects.equals(finishReason, that.finishReason) && Objects.equals(seed, that.seed); + } + + @Override + public int hashCode() { + return Objects.hash(finishReason, seed); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java new file mode 100644 index 00000000000..9e15cec7f60 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -0,0 +1,212 @@ +package org.springframework.ai.stabilityai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import java.io.IOException; +import java.util.List; +import java.util.function.Consumer; + +public class StabilityAiApi { + + public static final String DEFAULT_IMAGE_MODEL = "stable-diffusion-v1-6"; + + public static final String DEFAULT_BASE_URL = "https://api.stability.ai/v1"; + + private final RestClient restClient; + + private final String apiKey; + + private final String model; + + /** + * Create a new StabilityAI API. + * @param apiKey StabilityAI apiKey. + */ + public StabilityAiApi(String apiKey) { + this(apiKey, DEFAULT_IMAGE_MODEL, DEFAULT_BASE_URL, RestClient.builder()); + } + + public StabilityAiApi(String apiKey, String model) { + this(apiKey, model, DEFAULT_BASE_URL, RestClient.builder()); + } + + public StabilityAiApi(String apiKey, String model, String baseUrl) { + this(apiKey, model, baseUrl, RestClient.builder()); + } + + /** + * Create a new StabilityAI API. + * @param apiKey StabilityAI apiKey. + * @param model StabilityAI model. + * @param baseUrl api base URL. + * @param restClientBuilder RestClient builder. + */ + public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Builder restClientBuilder) { + + this.model = model; + this.apiKey = apiKey; + + Consumer jsonContentHeaders = headers -> { + headers.setBearerAuth(apiKey); + headers.setAccept(List.of(MediaType.APPLICATION_JSON)); // base64 in JSON + + // metadata or return + // image in bytes. + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() { + @Override + public boolean hasError(ClientHttpResponse response) throws IOException { + return response.getStatusCode().isError(); + } + + @Override + public void handleError(ClientHttpResponse response) throws IOException { + if (response.getStatusCode().isError()) { + throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(), + new ObjectMapper().readValue(response.getBody(), ResponseError.class))); + } + } + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ResponseError(@JsonProperty("id") String id, @JsonProperty("name") String name, + @JsonProperty("message") String message + + ) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts, + @JsonProperty("height") Integer height, @JsonProperty("width") Integer width, + @JsonProperty("cfg_scale") Float cfgScale, @JsonProperty("clip_guidance_preset") String clipGuidancePreset, + @JsonProperty("sampler") String sampler, @JsonProperty("samples") Integer samples, + @JsonProperty("seed") Long seed, @JsonProperty("steps") Integer steps, + @JsonProperty("style_present") String stylePreset) { + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record TextPrompts(@JsonProperty("text") String text, @JsonProperty("weight") Float weight) { + + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + List textPrompts; + + Integer height; + + Integer width; + + Float cfgScale; + + String clipGuidancePreset; + + String sampler; + + Integer samples; + + Long seed; + + Integer steps; + + String stylePreset; + + public Builder() { + + } + + public Builder withTextPrompts(List textPrompts) { + this.textPrompts = textPrompts; + return this; + } + + public Builder withHeight(Integer height) { + this.height = height; + return this; + } + + public Builder withWidth(Integer width) { + this.width = width; + return this; + } + + public Builder withCfgScale(Float cfgScale) { + this.cfgScale = cfgScale; + return this; + } + + public Builder withClipGuidancePreset(String clipGuidancePreset) { + this.clipGuidancePreset = clipGuidancePreset; + return this; + } + + public Builder withSampler(String sampler) { + this.sampler = sampler; + return this; + } + + public Builder withSamples(Integer samples) { + this.samples = samples; + return this; + } + + public Builder withSeed(Long seed) { + this.seed = seed; + return this; + } + + public Builder withSteps(Integer steps) { + this.steps = steps; + return this; + } + + public Builder withStylePreset(String stylePreset) { + this.stylePreset = stylePreset; + return this; + } + + public GenerateImageRequest build() { + return new GenerateImageRequest(textPrompts, height, width, cfgScale, clipGuidancePreset, sampler, + samples, seed, steps, stylePreset); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record GenerateImageResponse(@JsonProperty("result") String result, + @JsonProperty("artifacts") List artifacts) { + public record Artifacts(@JsonProperty("seed") long seed, @JsonProperty("base64") String base64, + @JsonProperty("finishReason") String finishReason) { + } + } + + public GenerateImageResponse generateImage(GenerateImageRequest request) { + Assert.notNull(request, "The request body can not be null."); + return this.restClient.post() + .uri("/generation/{model}/text-to-image", this.model) + .body(request) + .retrieve() + .body(GenerateImageResponse.class); + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java new file mode 100644 index 00000000000..358cfd8bc70 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -0,0 +1,23 @@ +package org.springframework.ai.stabilityai.api; + +import org.springframework.ai.image.ImageOptions; + +public interface StabilityAiImageOptions extends ImageOptions { + + Float getCfgScale(); + + String getClipGuidancePreset(); + + String getSampler(); + + Integer getSamples(); + + Long getSeed(); + + Integer getSteps(); + + String getStylePreset(); + + // extras json object... + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java new file mode 100644 index 00000000000..17f5bb9beb6 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsBuilder.java @@ -0,0 +1,79 @@ +package org.springframework.ai.stabilityai.api; + +public class StabilityAiImageOptionsBuilder { + + private StabilityAiImageOptionsImpl options; + + private StabilityAiImageOptionsBuilder() { + options = new StabilityAiImageOptionsImpl(); + } + + public static StabilityAiImageOptionsBuilder builder() { + return new StabilityAiImageOptionsBuilder(); + } + + public StabilityAiImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public StabilityAiImageOptionsBuilder withModel(String model) { + options.setModel(model); + return this; + } + + public StabilityAiImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public StabilityAiImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public StabilityAiImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public StabilityAiImageOptionsBuilder withCfgScale(Float cfgScale) { + options.setCfgScale(cfgScale); + return this; + } + + public StabilityAiImageOptionsBuilder withClipGuidancePreset(String clipGuidancePreset) { + options.setClipGuidancePreset(clipGuidancePreset); + return this; + } + + public StabilityAiImageOptionsBuilder withSampler(String sampler) { + options.setSampler(sampler); + return this; + } + + public StabilityAiImageOptionsBuilder withSeed(Long seed) { + options.setSeed(seed); + return this; + } + + public StabilityAiImageOptionsBuilder withSteps(Integer steps) { + options.setSteps(steps); + return this; + } + + public StabilityAiImageOptionsBuilder withSamples(Integer samples) { + options.setSamples(samples); + return this; + } + + public StabilityAiImageOptionsBuilder withStylePreset(String stylePreset) { + options.setStylePreset(stylePreset); + return this; + } + + public StabilityAiImageOptions build() { + return options; + } + +} diff --git a/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java new file mode 100644 index 00000000000..0f7213ea860 --- /dev/null +++ b/models/spring-ai-stabilityai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptionsImpl.java @@ -0,0 +1,140 @@ +package org.springframework.ai.stabilityai.api; + +public class StabilityAiImageOptionsImpl implements StabilityAiImageOptions { + + private Integer n; + + private String model; + + private Integer width; + + private Integer height; + + private String responseFormat; + + private Float cfgScale; + + private String clipGuidancePreset; + + private String sampler; + + private Integer samples; + + private Long seed; + + private Integer steps; + + private String stylePreset; + + public StabilityAiImageOptionsImpl() { + } + + @Override + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getWidth() { + return this.width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return this.height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + @Override + public String getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Float getCfgScale() { + return this.cfgScale; + } + + public void setCfgScale(Float cfgScale) { + this.cfgScale = cfgScale; + } + + @Override + public String getClipGuidancePreset() { + return this.clipGuidancePreset; + } + + public void setClipGuidancePreset(String clipGuidancePreset) { + this.clipGuidancePreset = clipGuidancePreset; + } + + @Override + public String getSampler() { + return this.sampler; + } + + public void setSampler(String sampler) { + this.sampler = sampler; + } + + @Override + public Integer getSamples() { + return this.samples; + } + + public void setSamples(Integer samples) { + this.samples = samples; + } + + @Override + public Long getSeed() { + return this.seed; + } + + public void setSeed(Long seed) { + this.seed = seed; + } + + @Override + public Integer getSteps() { + return this.steps; + } + + public void setSteps(Integer steps) { + this.steps = steps; + } + + @Override + public String getStylePreset() { + return this.stylePreset; + } + + public void setStylePreset(String stylePreset) { + this.stylePreset = stylePreset; + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java new file mode 100644 index 00000000000..d70f0f071c0 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiApiIT.java @@ -0,0 +1,64 @@ +package org.springframework.ai.stabilityai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.stabilityai.api.StabilityAiApi; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Base64; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") +public class StabilityAiApiIT { + + StabilityAiApi stabilityAiApi = new StabilityAiApi(System.getenv("STABILITYAI_API_KEY")); + + @Test + void generateImage() throws IOException { + + List textPrompts = List + .of(new StabilityAiApi.GenerateImageRequest.TextPrompts( + "A light cream colored mini golden doodle holding a sign that says 'Heading to BARCADE !'", 0.5f)); + var builder = StabilityAiApi.GenerateImageRequest.builder() + .withTextPrompts(textPrompts) + .withHeight(1024) + .withWidth(1024) + .withCfgScale(7f) + .withSamples(1) + .withSeed(123L) + .withSteps(30) + .withStylePreset("photographic"); + StabilityAiApi.GenerateImageRequest request = builder.build(); + StabilityAiApi.GenerateImageResponse response = stabilityAiApi.generateImage(request); + + assertThat(response).isNotNull(); + List artifacts = response.artifacts(); + writeToFile(artifacts); + assertThat(artifacts).hasSize(1); + var firstArtifact = artifacts.get(0); + assertThat(firstArtifact.base64()).isNotEmpty(); + assertThat(firstArtifact.seed()).isPositive(); + assertThat(firstArtifact.finishReason()).isEqualTo("SUCCESS"); + + } + + private static void writeToFile(List artifacts) throws IOException { + int counter = 0; + String systemTempDir = System.getProperty("java.io.tmpdir"); + for (StabilityAiApi.GenerateImageResponse.Artifacts artifact : artifacts) { + counter++; + byte[] imageBytes = Base64.getDecoder().decode(artifact.base64()); + String fileName = String.format("dog%d.png", counter); + String filePath = systemTempDir + File.separator + fileName; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java new file mode 100644 index 00000000000..ea6f74d5734 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageClientIT.java @@ -0,0 +1,50 @@ +package org.springframework.ai.stabilityai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.image.*; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Base64; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = StabilityAiImageTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "STABILITYAI_API_KEY", matches = ".*") +public class StabilityAiImageClientIT { + + @Autowired + protected ImageClient stabilityAiImageClient; + + @Test + void imageAsBase64Test() throws IOException { + ImagePrompt imagePrompt = new ImagePrompt( + "A light cream colored mini golden doodle holding a sign that says 'I want to go with you on vacation!'"); + + ImageResponse imageResponse = this.stabilityAiImageClient.call(imagePrompt); + + ImageGeneration imageGeneration = imageResponse.getResult(); + Image image = imageGeneration.getOutput(); + + assertThat(image.getB64Json()).isNotEmpty(); + + writeFile(image); + } + + private static void writeFile(Image image) throws IOException { + byte[] imageBytes = Base64.getDecoder().decode(image.getB64Json()); + String systemTempDir = System.getProperty("java.io.tmpdir"); + String filePath = systemTempDir + File.separator + "dog.png"; + File file = new File(filePath); + try (FileOutputStream fos = new FileOutputStream(file)) { + fos.write(imageBytes); + } + } + +} diff --git a/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java new file mode 100644 index 00000000000..a4ff9ed0390 --- /dev/null +++ b/models/spring-ai-stabilityai/src/test/java/org/springframework/ai/stabilityai/StabilityAiImageTestConfiguration.java @@ -0,0 +1,30 @@ +package org.springframework.ai.stabilityai; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +@SpringBootConfiguration +public class StabilityAiImageTestConfiguration { + + @Bean + public StabilityAiApi stabilityAiApi() { + return new StabilityAiApi(getApiKey()); + } + + @Bean + StabilityAiImageClient stabilityAiImageClient(StabilityAiApi stabilityAiApi) { + return new StabilityAiImageClient(stabilityAiApi); + } + + private String getApiKey() { + String apiKey = System.getenv("STABILITYAI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name STABILITYAI_API_KEY"); + } + return apiKey; + } + +} diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java index b8d7695fe8e..094dd2438e0 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java @@ -56,7 +56,7 @@ public class ResourceCacheService { private List excludedUriSchemas = new ArrayList<>(List.of("file", "classpath")); public ResourceCacheService() { - this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath()); + this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-generative").getAbsolutePath()); } public ResourceCacheService(String rootCacheDirectory) { diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java index 66b1dec0f5b..087510f279f 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java @@ -41,10 +41,10 @@ public class TransformersEmbeddingClient extends AbstractEmbeddingClient impleme private static final Log logger = LogFactory.getLog(TransformersEmbeddingClient.class); - // ONNX tokenizer for the all-MiniLM-L6-v2 model + // ONNX tokenizer for the all-MiniLM-L6-v2 generative public final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json"; - // ONNX model for all-MiniLM-L6-v2 pre-trained transformer: + // ONNX generative for all-MiniLM-L6-v2 pre-trained transformer: // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 public final static String DEFAULT_ONNX_MODEL_URI = "https://github.com/spring-projects/spring-ai/raw/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx"; @@ -70,7 +70,7 @@ public class TransformersEmbeddingClient extends AbstractEmbeddingClient impleme private OrtEnvironment environment; /** - * Runtime session that wraps the ONNX model and enables inference calls. + * Runtime session that wraps the ONNX generative and enables inference calls. */ private OrtSession session; @@ -181,7 +181,7 @@ public void afterPropertiesSet() throws Exception { logger.info("Model output names: " + onnxModelOutputs.stream().collect(Collectors.joining(", "))); Assert.isTrue(onnxModelOutputs.contains(this.modelOutputName), - "The model output names doesn't contain expected: " + this.modelOutputName); + "The generative output names doesn't contain expected: " + this.modelOutputName); } private Resource getCachedResource(Resource resource) { diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java index 26f5076e1c2..887f3fe9d46 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java @@ -58,7 +58,7 @@ public static NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask public static void main(String[] args) throws Exception { String TOKENIZER_URI = "classpath:/onnx/tokenizer.json"; - String MODEL_URI = "classpath:/onnx/model.onnx"; + String MODEL_URI = "classpath:/onnx/generative.onnx"; var tokenizerResource = new DefaultResourceLoader().getResource(TOKENIZER_URI); var modelResource = new DefaultResourceLoader().getResource(MODEL_URI); diff --git a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java index c7ec053c31f..1c922978180 100644 --- a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java +++ b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java @@ -22,8 +22,8 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.vertex.api.VertexAiApi; import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest; import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageResponse; @@ -71,15 +71,15 @@ public VertexAiChatClient withCandidateCount(Integer maxTokens) { } @Override - public ChatResponse generate(Prompt prompt) { + public ChatResponse call(Prompt prompt) { - String vertexContext = prompt.getMessages() + String vertexContext = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.SYSTEM) .map(m -> m.getContent()) .collect(Collectors.joining("\n")); - List vertexMessages = prompt.getMessages() + List vertexMessages = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) .map(m -> new VertexAiApi.Message(m.getMessageType().getValue(), m.getContent())) diff --git a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java index 9e9609dd20f..811be1fdc91 100644 --- a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java +++ b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java @@ -112,7 +112,7 @@ public class VertexAiApi { private final String embeddingModel; /** - * Create an new chat completion api. + * Create a new chat completion api. * @param apiKey vertex apiKey. */ public VertexAiApi(String apiKey) { @@ -120,7 +120,7 @@ public VertexAiApi(String apiKey) { } /** - * Create an new chat completion api. + * Create a new chat completion api. * @param baseUrl api base URL. * @param apiKey vertex apiKey. * @param model vertex model. diff --git a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java index 78645d27a2a..ba6182f652b 100644 --- a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java +++ b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/api/VertexAiApiTests.java @@ -78,7 +78,7 @@ public void generateMessage() throws JsonProcessingException { List.of(new VertexAiApi.GenerateMessageResponse.ContentFilter(BlockedReason.SAFETY, "reason"))); server - .expect(requestToUriTemplate("/models/{model}:generateMessage?key={apiKey}", + .expect(requestToUriTemplate("/models/{generative}:generateMessage?key={apiKey}", VertexAiApi.DEFAULT_GENERATE_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(request))) @@ -99,8 +99,8 @@ public void embedText() throws JsonProcessingException { Embedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3)); server - .expect(requestToUriTemplate("/models/{model}:embedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL, - TEST_API_KEY)) + .expect(requestToUriTemplate("/models/{generative}:embedText?key={apiKey}", + VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(Map.of("text", text)))) .andRespond(withSuccess(objectMapper.writeValueAsString(Map.of("embedding", expectedEmbedding)), @@ -122,7 +122,7 @@ public void batchEmbedText() throws JsonProcessingException { new Embedding(List.of(0.4, 0.5, 0.6))); server - .expect(requestToUriTemplate("/models/{model}:batchEmbedText?key={apiKey}", + .expect(requestToUriTemplate("/models/{generative}:batchEmbedText?key={apiKey}", VertexAiApi.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY)) .andExpect(method(HttpMethod.POST)) .andExpect(content().json(objectMapper.writeValueAsString(Map.of("texts", texts)))) diff --git a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java index 69bd76d686e..90bc23e8e8a 100644 --- a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java +++ b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/generation/VertexAiChatGenerationClientIT.java @@ -12,11 +12,11 @@ import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.parser.MapOutputParser; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.vertex.VertexAiChatClient; import org.springframework.ai.vertex.api.VertexAiApi; import org.springframework.beans.factory.annotation.Autowired; @@ -48,8 +48,8 @@ void roleTest() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = client.generate(prompt); - assertThat(response.getGeneration().getContent()).contains("Bartholomew"); + ChatResponse response = client.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Bartholomew"); } // @Test @@ -65,9 +65,9 @@ void outputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "ice cream flavors.", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.client.generate(prompt).getGeneration(); + Generation generation = this.client.call(prompt).getResult(); - List list = outputParser.parse(generation.getContent()); + List list = outputParser.parse(generation.getOutput().getContent()); assertThat(list).hasSize(5); } @@ -84,9 +84,9 @@ void mapOutputParser() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - Map result = outputParser.parse(generation.getContent()); + Map result = outputParser.parse(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } @@ -106,9 +106,9 @@ void beanOutputParserRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = client.generate(prompt).getGeneration(); + Generation generation = client.call(prompt).getResult(); - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent()); + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } diff --git a/pom.xml b/pom.xml index 338ef2c7522..09e012bc6d8 100644 --- a/pom.xml +++ b/pom.xml @@ -14,14 +14,15 @@ spring-ai-core + models/spring-ai-transformers + models/spring-ai-postgresml models/spring-ai-bedrock models/spring-ai-azure-openai models/spring-ai-huggingface models/spring-ai-ollama models/spring-ai-openai models/spring-ai-vertex-ai - models/spring-ai-transformers - models/spring-ai-postgresml + models/spring-ai-stabilityai spring-ai-test spring-ai-spring-boot-autoconfigure spring-ai-spring-boot-starters/spring-ai-starter-openai diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java index 373339467d5..89162ab2d4b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatClient.java @@ -16,17 +16,18 @@ package org.springframework.ai.chat; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.model.ModelClient; @FunctionalInterface -public interface ChatClient { +public interface ChatClient extends ModelClient { - default String generate(String message) { + default String call(String message) { Prompt prompt = new Prompt(new UserMessage(message)); - return generate(prompt).getGeneration().getContent(); + return call(prompt).getResult().getOutput().getContent(); } - ChatResponse generate(Prompt prompt); + ChatResponse call(Prompt prompt); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java new file mode 100644 index 00000000000..84d9f2465d2 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java @@ -0,0 +1,24 @@ +package org.springframework.ai.chat; + +import org.springframework.ai.model.ModelOptions; + +/** + * portable options + */ +public interface ChatOptions extends ModelOptions { + + // determine portable optionsb + + Float getTemperature(); + + void setTemperature(Float temperature); + + Float getTopP(); + + void setTopP(Float topP); + + Integer getTopK(); + + void setTopK(Integer topK); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java new file mode 100644 index 00000000000..33866b3bf02 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java @@ -0,0 +1,73 @@ +package org.springframework.ai.chat; + +public class ChatOptionsBuilder { + + private class ChatOptionsImpl implements ChatOptions { + + private Float temperature; + + private Float topP; + + private Integer topK; + + @Override + public Float getTemperature() { + return temperature; + } + + @Override + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + @Override + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + @Override + public void setTopK(Integer topK) { + this.topK = topK; + } + + } + + private final ChatOptionsImpl options = new ChatOptionsImpl(); + + private ChatOptionsBuilder() { + } + + public static ChatOptionsBuilder builder() { + return new ChatOptionsBuilder(); + } + + public ChatOptionsBuilder withTemperature(Float temperature) { + options.setTemperature(temperature); + return this; + } + + public ChatOptionsBuilder withTopP(Float topP) { + options.setTopP(topP); + return this; + } + + public ChatOptionsBuilder withTopK(Integer topK) { + options.setTopK(topK); + return this; + } + + public ChatOptions build() { + return options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java index f68f6bb317a..6ad2dab3e8c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatResponse.java @@ -17,42 +17,39 @@ import java.util.List; -import org.springframework.ai.metadata.GenerationMetadata; -import org.springframework.ai.metadata.PromptMetadata; -import org.springframework.lang.Nullable; +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; /** * The chat completion (e.g. generation) response returned by an AI provider. */ -public class ChatResponse { +public class ChatResponse implements ModelResponse { - private final GenerationMetadata metadata; + private final ChatResponseMetadata chatResponseMetadata; /** * List of generated messages returned by the AI provider. */ private final List generations; - private PromptMetadata promptMetadata; - /** * Construct a new {@link ChatResponse} instance without metadata. * @param generations the {@link List} of {@link Generation} returned by the AI * provider. */ public ChatResponse(List generations) { - this(generations, GenerationMetadata.NULL); + this(generations, ChatResponseMetadata.NULL); } /** * Construct a new {@link ChatResponse} instance. * @param generations the {@link List} of {@link Generation} returned by the AI * provider. - * @param metadata {@link GenerationMetadata} containing information about the use of - * the AI provider's API. + * @param chatResponseMetadata {@link ChatResponseMetadata} containing information + * about the use of the AI provider's API. */ - public ChatResponse(List generations, GenerationMetadata metadata) { - this.metadata = metadata; + public ChatResponse(List generations, ChatResponseMetadata chatResponseMetadata) { + this.chatResponseMetadata = chatResponseMetadata; this.generations = List.copyOf(generations); } @@ -63,45 +60,25 @@ public ChatResponse(List generations, GenerationMetadata metadata) { * multiple output {@link Generation generations}. * @return the {@link List} of {@link Generation generated outputs}. */ - public List getGenerations() { + + @Override + public List getResults() { return this.generations; } /** * @return Returns the first {@link Generation} in the generations list. */ - public Generation getGeneration() { + public Generation getResult() { return this.generations.get(0); } /** - * @return Returns {@link GenerationMetadata} containing information about the use of - * the AI provider's API. - */ - public GenerationMetadata getGenerationMetadata() { - return this.metadata; - } - - /** - * @return {@link PromptMetadata} containing information on prompt processing by the - * AI. - */ - public PromptMetadata getPromptMetadata() { - PromptMetadata promptMetadata = this.promptMetadata; - return promptMetadata != null ? promptMetadata : PromptMetadata.empty(); - } - - /** - * Builder method used to include {@link PromptMetadata} returned in the AI response - * when processing the prompt. - * @param promptMetadata {@link PromptMetadata} returned by the AI in the response - * when processing the prompt. - * @return this {@link ChatResponse}. - * @see #getPromptMetadata() + * @return Returns {@link ChatResponseMetadata} containing information about the use + * of the AI provider's API. */ - public ChatResponse withPromptMetadata(@Nullable PromptMetadata promptMetadata) { - this.promptMetadata = promptMetadata; - return this; + public ChatResponseMetadata getMetadata() { + return this.chatResponseMetadata; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java index 71fd3455a45..5df07190ff1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java @@ -16,46 +16,65 @@ package org.springframework.ai.chat; -import java.util.Collections; import java.util.Map; +import java.util.Objects; -import org.springframework.ai.metadata.ChoiceMetadata; -import org.springframework.ai.prompt.messages.AbstractMessage; -import org.springframework.ai.prompt.messages.MessageType; +import org.springframework.ai.model.ModelResult; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.lang.Nullable; /** * Represents a response returned by the AI. */ -public class Generation extends AbstractMessage { +public class Generation implements ModelResult { - private ChoiceMetadata choiceMetadata; + private AssistantMessage assistantMessage; + + private ChatGenerationMetadata chatGenerationMetadata; public Generation(String text) { - this(text, Collections.emptyMap()); + this.assistantMessage = new AssistantMessage(text); } - public Generation(String content, Map properties) { - super(MessageType.ASSISTANT, content, properties); + public Generation(String text, Map properties) { + this.assistantMessage = new AssistantMessage(text, properties); } - public Generation(String content, Map properties, MessageType type) { - super(type, content, properties); + @Override + public AssistantMessage getOutput() { + return this.assistantMessage; } - public ChoiceMetadata getChoiceMetadata() { - ChoiceMetadata choiceMetadata = this.choiceMetadata; - return choiceMetadata != null ? choiceMetadata : ChoiceMetadata.NULL; + public ChatGenerationMetadata getResultMetadata() { + ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata; + return chatGenerationMetadata != null ? chatGenerationMetadata : ChatGenerationMetadata.NULL; } - public Generation withChoiceMetadata(@Nullable ChoiceMetadata choiceMetadata) { - this.choiceMetadata = choiceMetadata; + public Generation withGenerationMetadata(@Nullable ChatGenerationMetadata chatGenerationMetadata) { + this.chatGenerationMetadata = chatGenerationMetadata; return this; } + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Generation that)) + return false; + return Objects.equals(assistantMessage, that.assistantMessage) + && Objects.equals(chatGenerationMetadata, that.chatGenerationMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(assistantMessage, chatGenerationMetadata); + } + @Override public String toString() { - return "Generation{" + "text='" + content + '\'' + ", info=" + properties + '}'; + return "Generation{" + "assistantMessage=" + assistantMessage + ", chatGenerationMetadata=" + + chatGenerationMetadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java index d697636151a..beb67ae6a3c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java @@ -18,7 +18,7 @@ import reactor.core.publisher.Flux; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.prompt.Prompt; @FunctionalInterface public interface StreamingChatClient { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java similarity index 97% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 9f4c123ecd0..24ecf3c1e06 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; @@ -31,7 +31,7 @@ public abstract class AbstractMessage implements Message { protected String content; /** - * Additional options for the message to influence the response, not a model map. + * Additional options for the message to influence the response, not a generative map. */ protected Map properties = new HashMap<>(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java similarity index 82% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index 2e064bd65f8..02e29146953 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AssistantMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; /** - * Lets the model know the content was generated as a response to the user. This role - * indicates messages that the model has previously generated in the conversation. By - * including assistant messages in the series, you provide context to the model about + * Lets the generative know the content was generated as a response to the user. This role + * indicates messages that the generative has previously generated in the conversation. By + * including assistant messages in the series, you provide context to the generative about * prior exchanges in the conversation. */ public class AssistantMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java index 6dc660ccb27..194aa54afaf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/ChatMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/ChatMessage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java index c00520dcdda..5fa90f50be0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/FunctionMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java similarity index 94% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java index a405eb10b81..e1e6eda348c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java index 267727c7c00..6cb07a045ba 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/MessageType.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; public enum MessageType { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java similarity index 88% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java index 31179045122..ac70e593238 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/SystemMessage.java @@ -14,15 +14,16 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; /** * A message of the type 'system' passed as input. The system message gives high level * instructions for the conversation. This role typically provides high-level instructions - * for the conversation. For example, you might use a system message to instruct the model - * to behave like a certain character or to provide answers in a specific format. + * for the conversation. For example, you might use a system message to instruct the + * generative to behave like a certain character or to provide answers in a specific + * format. */ public class SystemMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java similarity index 93% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 8c7fdcae6fa..6b91cc1d2e1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.ai.prompt.messages; +package org.springframework.ai.chat.messages; import org.springframework.core.io.Resource; /** * A message of the type 'user' passed as input Messages with the user role are from the * end-user or developer. They represent questions, prompts, or any input that you want - * the model to respond to. + * the generative to respond to. */ public class UserMessage extends AbstractMessage { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java index 9b0b3a80315..f485c7a294a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractRateLimit.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java index 2adad4b9de3..edec477abac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/AbstractUsage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; /** * Abstract base class used as a foundation for implementing {@link Usage}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java similarity index 73% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java index 9c15fbdb226..d9f5fc56eb7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; +import org.springframework.ai.model.ResultMetadata; import org.springframework.lang.Nullable; /** @@ -25,21 +26,21 @@ * @author John Blum * @since 0.7.0 */ -public interface ChoiceMetadata { +public interface ChatGenerationMetadata extends ResultMetadata { - ChoiceMetadata NULL = ChoiceMetadata.from(null, null); + ChatGenerationMetadata NULL = ChatGenerationMetadata.from(null, null); /** - * Factory method used to construct a new {@link ChoiceMetadata} from the given - * {@link String finish reason} and content filter metadata. + * Factory method used to construct a new {@link ChatGenerationMetadata} from the + * given {@link String finish reason} and content filter metadata. * @param finishReason {@link String} contain the reason for the choice completion. * @param contentFilterMetadata underlying AI provider metadata for filtering applied * to generation content. - * @return a new {@link ChoiceMetadata} from the given {@link String finish reason} - * and content filter metadata. + * @return a new {@link ChatGenerationMetadata} from the given {@link String finish + * reason} and content filter metadata. */ - static ChoiceMetadata from(String finishReason, Object contentFilterMetadata) { - return new ChoiceMetadata() { + static ChatGenerationMetadata from(String finishReason, Object contentFilterMetadata) { + return new ChatGenerationMetadata() { @Override @SuppressWarnings("unchecked") diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java similarity index 79% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java index 3bd641eca24..12c9d0b5c8f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatResponseMetadata.java @@ -14,7 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; + +import org.springframework.ai.model.ResponseMetadata; /** * Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI @@ -23,9 +25,9 @@ * @author John Blum * @since 0.7.0 */ -public interface GenerationMetadata { +public interface ChatResponseMetadata extends ResponseMetadata { - GenerationMetadata NULL = new GenerationMetadata() { + ChatResponseMetadata NULL = new ChatResponseMetadata() { }; /** @@ -46,4 +48,8 @@ default Usage getUsage() { return Usage.NULL; } + default PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java similarity index 98% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java index 302481b6e38..6ecb0924603 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/PromptMetadata.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.util.Arrays; import java.util.Optional; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java similarity index 98% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java index 40bb933d03b..66bdebd2b40 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/RateLimit.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; import java.time.Duration; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java similarity index 97% rename from spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index 7179d5e127e..9cc1dbb711a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.metadata; +package org.springframework.ai.chat.metadata; /** * Abstract Data Type (ADT) encapsulating metadata on the usage of an AI provider's API diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java new file mode 100644 index 00000000000..98d92eb7127 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/package-info.java @@ -0,0 +1,14 @@ +/** + * The org.sf.ai.chat package represents the bounded context for the Chat Model within the + * AI generative model domain. This package extends the core domain defined in + * org.sf.ai.generative, providing implementations specific to chat-based generative AI + * interactions. + * + * In line with Domain-Driven Design principles, this package includes implementations of + * entities and value objects specific to the chat context, such as ChatPrompt and + * ChatResponse, adhering to the ubiquitous language of chat interactions in AI models. + * + * This bounded context is designed to encapsulate all aspects of chat-based AI + * functionalities, maintaining a clear boundary from other contexts within the AI domain. + */ +package org.springframework.ai.chat; \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java similarity index 90% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java index 715763f47fe..8a167f00590 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/AssistantPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AssistantPromptTemplate.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java similarity index 95% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java index dbef2e4eab9..800c4586ecb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.ArrayList; import java.util.List; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java similarity index 94% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java index 41504364ffd..4c7ce981ffc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/FunctionPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/FunctionPromptTemplate.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; public class FunctionPromptTemplate extends PromptTemplate { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java similarity index 53% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index b0acf5fb4c4..29f671e7b57 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -14,19 +14,23 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.model.ModelRequest; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import java.util.Collections; import java.util.List; import java.util.Objects; -public class Prompt { +public class Prompt implements ModelRequest> { private final List messages; + private ModelOptions modelOptions; + public Prompt(String contents) { this(new UserMessage(contents)); } @@ -39,36 +43,53 @@ public Prompt(List messages) { this.messages = messages; } + public Prompt(String contents, ModelOptions modelOptions) { + this(new UserMessage(contents), modelOptions); + } + + public Prompt(Message message, ModelOptions modelOptions) { + this(Collections.singletonList(message), modelOptions); + } + + public Prompt(List messages, ModelOptions modelOptions) { + this.messages = messages; + this.modelOptions = modelOptions; + } + public String getContents() { StringBuilder sb = new StringBuilder(); - for (Message message : getMessages()) { + for (Message message : getInstructions()) { sb.append(message.getContent()); } return sb.toString(); } - public List getMessages() { + public ModelOptions getOptions() { + return modelOptions; + } + + @Override + public List getInstructions() { return this.messages; } @Override public String toString() { - return "Prompt{" + "messages=" + messages + '}'; + return "Prompt{" + "messages=" + messages + ", modelOptions=" + modelOptions + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; - if (o == null || getClass() != o.getClass()) + if (!(o instanceof Prompt prompt)) return false; - Prompt prompt = (Prompt) o; - return Objects.equals(messages, prompt.messages); + return Objects.equals(messages, prompt.messages) && Objects.equals(modelOptions, prompt.modelOptions); } @Override public int hashCode() { - return Objects.hash(messages); + return Objects.hash(messages, modelOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java index 6207aefeadf..4ac2a1b92d9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java @@ -14,13 +14,13 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import org.antlr.runtime.Token; import org.antlr.runtime.TokenStream; import org.springframework.ai.parser.OutputParser; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.core.io.Resource; import org.springframework.util.StreamUtils; import org.stringtemplate.v4.ST; @@ -189,7 +189,7 @@ public Prompt create(Map model) { return new Prompt(render(model)); } - protected Set getInputVariables() { + public Set getInputVariables() { TokenStream tokens = this.st.impl.tokens; return IntStream.range(0, tokens.range()) .mapToObj(tokens::get) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java similarity index 94% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java index e1637384f88..3381eaddd50 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java similarity index 66% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java index b12384c65ef..361de4ade93 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateChatActions.java @@ -1,6 +1,6 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.List; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java similarity index 61% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java index bf48bf5460d..47e24620c10 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateMessageActions.java @@ -1,6 +1,6 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.chat.messages.Message; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java similarity index 75% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java index 8c041e81c13..86495a49e3e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateStringActions.java @@ -1,4 +1,4 @@ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java similarity index 89% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java index 73d84902c98..539287d070b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/SystemPromptTemplate.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.core.io.Resource; import java.util.Map; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java similarity index 96% rename from spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java index 9782db6b3b4..1001ee26600 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/TemplateFormat.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/TemplateFormat.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.prompt; +package org.springframework.ai.chat.prompt; public enum TemplateFormat { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index fcc7a570329..f02a2d0b5d0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -71,7 +71,7 @@ public class DefaultContentFormatter implements ContentFormatter { private final List excludedInferenceMetadataKeys; /** - * Metadata keys that are excluded from text for the embed model. + * Metadata keys that are excluded from text for the embed generative. */ private final List excludedEmbedMetadataKeys; @@ -157,7 +157,8 @@ public Builder withTextTemplate(String textTemplate) { } /** - * Configures the excluded Inference metadata keys to filter out from the model. + * Configures the excluded Inference metadata keys to filter out from the + * generative. * @param excludedInferenceMetadataKeys Excluded inference metadata keys to use. * @return this builder */ @@ -174,7 +175,7 @@ public Builder withExcludedInferenceMetadataKeys(String... keys) { } /** - * Configures the excluded Embed metadata keys to filter out from the model. + * Configures the excluded Embed metadata keys to filter out from the generative. * @param excludedEmbedMetadataKeys Excluded Embed metadata keys to use. * @return this builder */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java index b0bd551c7ff..fdc519ad667 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java @@ -38,7 +38,8 @@ public interface EmbeddingClient { EmbeddingResponse embedForResponse(List texts); /** - * @return the number of dimensions of the embedded vectors. It is model specific. + * @return the number of dimensions of the embedded vectors. It is generative + * specific. */ default int dimensions() { return embed("Test String").size(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java index 451cea484f1..91d621b0caf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java @@ -31,11 +31,11 @@ public class EmbeddingUtil { private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); /** - * Return the dimension of the requested embedding model name. If the model name is - * unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed and count - * the response dimensions. + * Return the dimension of the requested embedding generative name. If the generative + * name is unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed + * and count the response dimensions. * @param embeddingClient Fall-back client to determine, empirically the dimensions. - * @param modelName Embedding model name to retrieve the dimensions for. + * @param modelName Embedding generative name to retrieve the dimensions for. * @param dummyContent Dummy content to use for the empirical dimension calculation. * @return Returns the embedding dimensions for the modelName. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java new file mode 100644 index 00000000000..561d7398a37 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java @@ -0,0 +1,57 @@ +package org.springframework.ai.image; + +import java.util.Objects; + +public class Image { + + /** + * The URL where the image can be accessed. + */ + private String url; + + /** + * Base64 encoded image string. + */ + private String b64Json; + + public Image(String url, String b64Json) { + this.url = url; + this.b64Json = b64Json; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getB64Json() { + return b64Json; + } + + public void setB64Json(String b64Json) { + this.b64Json = b64Json; + } + + @Override + public String toString() { + return "Image{" + "url='" + url + '\'' + ", b64Json='" + b64Json + '\'' + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Image image)) + return false; + return Objects.equals(url, image.url) && Objects.equals(b64Json, image.b64Json); + } + + @Override + public int hashCode() { + return Objects.hash(url, b64Json); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java new file mode 100644 index 00000000000..c12db623ef0 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java @@ -0,0 +1,10 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelClient; + +@FunctionalInterface +public interface ImageClient extends ModelClient { + + ImageResponse call(ImagePrompt request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java new file mode 100644 index 00000000000..98cb0654fbf --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -0,0 +1,35 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelResult; + +public class ImageGeneration implements ModelResult { + + private ImageGenerationMetadata imageGenerationMetadata; + + private Image image; + + public ImageGeneration(Image image) { + this.image = image; + } + + public ImageGeneration(Image image, ImageGenerationMetadata imageGenerationMetadata) { + this.image = image; + this.imageGenerationMetadata = imageGenerationMetadata; + } + + @Override + public Image getOutput() { + return image; + } + + @Override + public ImageGenerationMetadata getResultMetadata() { + return imageGenerationMetadata; + } + + @Override + public String toString() { + return "ImageGeneration{" + "imageGenerationMetadata=" + imageGenerationMetadata + ", image=" + image + '}'; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java new file mode 100644 index 00000000000..7be56cfc959 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java @@ -0,0 +1,7 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ResultMetadata; + +public interface ImageGenerationMetadata extends ResultMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java new file mode 100644 index 00000000000..153b606df5c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java @@ -0,0 +1,47 @@ +package org.springframework.ai.image; + +import java.util.Objects; + +public class ImageMessage { + + private String text; + + private Float weight; + + public ImageMessage(String text) { + this.text = text; + } + + public ImageMessage(String text, Float weight) { + this.text = text; + this.weight = weight; + } + + public String getText() { + return text; + } + + public Float getWeight() { + return weight; + } + + @Override + public String toString() { + return "mageMessage{" + "text='" + text + '\'' + ", weight=" + weight + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImageMessage that)) + return false; + return Objects.equals(text, that.text) && Objects.equals(weight, that.weight); + } + + @Override + public int hashCode() { + return Objects.hash(text, weight); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java new file mode 100644 index 00000000000..5d8ededb006 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -0,0 +1,17 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelOptions; + +public interface ImageOptions extends ModelOptions { + + Integer getN(); + + String getModel(); + + Integer getWidth(); + + Integer getHeight(); + + String getResponseFormat(); // openai - url or base64 : stability ai byte[] or base64 + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java new file mode 100644 index 00000000000..f5e56d96b9c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -0,0 +1,103 @@ +package org.springframework.ai.image; + +public class ImageOptionsBuilder { + + private class ImageModelOptionsImpl implements ImageOptions { + + private Integer n; + + private String model; + + private Integer width; + + private Integer height; + + private String responseFormat; + + @Override + public Integer getN() { + return n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Integer getWidth() { + return width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + } + + private final ImageModelOptionsImpl options = new ImageModelOptionsImpl(); + + private ImageOptionsBuilder() { + + } + + public static ImageOptionsBuilder builder() { + return new ImageOptionsBuilder(); + } + + public ImageOptionsBuilder withN(Integer n) { + options.setN(n); + return this; + } + + public ImageOptionsBuilder withModel(String model) { + options.setModel(model); + return this; + } + + public ImageOptionsBuilder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public ImageOptionsBuilder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public ImageOptionsBuilder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public ImageOptions build() { + return options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java new file mode 100644 index 00000000000..99589894c0a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -0,0 +1,65 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class ImagePrompt implements ModelRequest> { + + private final List messages; + + private ImageOptions imageModelOptions; + + public ImagePrompt(List messages) { + this.messages = messages; + } + + public ImagePrompt(List messages, ImageOptions imageModelOptions) { + this.messages = messages; + this.imageModelOptions = imageModelOptions; + } + + public ImagePrompt(ImageMessage imageMessage, ImageOptions imageOptions) { + this(Collections.singletonList(imageMessage), imageOptions); + } + + public ImagePrompt(String instructions, ImageOptions imageOptions) { + this(new ImageMessage(instructions), imageOptions); + } + + public ImagePrompt(String instructions) { + this(new ImageMessage(instructions), ImageOptionsBuilder.builder().build()); + } + + @Override + public List getInstructions() { + return messages; + } + + @Override + public ImageOptions getOptions() { + return imageModelOptions; + } + + @Override + public String toString() { + return "NewImagePrompt{" + "messages=" + messages + ", imageModelOptions=" + imageModelOptions + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImagePrompt that)) + return false; + return Objects.equals(messages, that.messages) && Objects.equals(imageModelOptions, that.imageModelOptions); + } + + @Override + public int hashCode() { + return Objects.hash(messages, imageModelOptions); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java new file mode 100644 index 00000000000..84c4e93ccdd --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -0,0 +1,60 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.model.ResponseMetadata; + +import java.util.List; +import java.util.Objects; + +public class ImageResponse implements ModelResponse { + + private final ImageResponseMetadata imageResponseMetadata; + + private final List imageGenerations; + + public ImageResponse(List generations) { + this(generations, ImageResponseMetadata.NULL); + } + + public ImageResponse(List generations, ImageResponseMetadata imageResponseMetadata) { + this.imageResponseMetadata = imageResponseMetadata; + this.imageGenerations = List.copyOf(generations); + } + + @Override + public ImageGeneration getResult() { + return imageGenerations.get(0); + } + + @Override + public List getResults() { + return imageGenerations; + } + + @Override + public ImageResponseMetadata getMetadata() { + return imageResponseMetadata; + } + + @Override + public String toString() { + return "ImageResponse{" + "imageResponseMetadata=" + imageResponseMetadata + ", imageGenerations=" + + imageGenerations + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ImageResponse that)) + return false; + return Objects.equals(imageResponseMetadata, that.imageResponseMetadata) + && Objects.equals(imageGenerations, that.imageGenerations); + } + + @Override + public int hashCode() { + return Objects.hash(imageResponseMetadata, imageGenerations); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java new file mode 100644 index 00000000000..512303ca490 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -0,0 +1,14 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ResponseMetadata; + +public interface ImageResponseMetadata extends ResponseMetadata { + + ImageResponseMetadata NULL = new ImageResponseMetadata() { + }; + + default Long created() { + return System.currentTimeMillis(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java new file mode 100644 index 00000000000..579e8e67376 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java @@ -0,0 +1,12 @@ +package org.springframework.ai.image; + +import org.springframework.ai.model.ModelClient; + +import java.util.List; + +@FunctionalInterface +public interface NewImageClient extends ModelClient { + + ImageResponse call(ImagePrompt request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java new file mode 100644 index 00000000000..2d3d2f3bf95 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java @@ -0,0 +1,26 @@ +package org.springframework.ai.model; + +/** + * The ModelClient interface provides a generic API for invoking AI models. It is designed + * to handle the interaction with various types of AI models by abstracting the process of + * sending requests and receiving responses. The interface uses Java generics to + * accommodate different types of requests and responses, enhancing flexibility and + * adaptability across different AI model implementations. + * + * @param the generic type of the request to the AI model + * @param the generic type of the response from the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelClient, TRes extends ModelResponse> { + + /** + * Executes a method call to the AI model. + * @param request the request object to be sent to the AI model + * @param the generic type of the request object + * @param the generic type of the response from the AI model + * @return the response from the AI model + */ + TRes call(TReq request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java new file mode 100644 index 00000000000..ae01e7a770b --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java @@ -0,0 +1,15 @@ +package org.springframework.ai.model; + +/** + * Interface representing the customizable options for AI model interactions. This + * interface allows for the specification of various settings and parameters that can + * influence the behavior and output of AI models. It is designed to provide flexibility + * and adaptability in different AI scenarios, ensuring that the AI models can be + * fine-tuned according to specific requirements. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelOptions { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java new file mode 100644 index 00000000000..b84aa32cb2a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -0,0 +1,72 @@ +package org.springframework.ai.model; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.Map; +import java.util.stream.Collectors; + +public abstract class ModelOptionsUtils { + + private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private ModelOptionsUtils() { + + } + + /** + * Merges the source object into the target object and returns an object represented + * by the given class. The source null values are ignored. + * @param they type of the class to return. + * @param source the source object to merge. + * @param target the target object to merge into. + * @param clazz the class to return. + * @return the merged object represented by the given class. + */ + public static T merge(Object source, Object target, Class clazz) { + Map sourceMap = objectToMap(source); + Map targetMap = objectToMap(target); + + targetMap.putAll(sourceMap.entrySet() + .stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()))); + + return mapToClass(targetMap, clazz); + } + + /** + * Converts the given object to a Map. + * @param source the object to convert to a Map. + * @return the converted Map. + */ + public static Map objectToMap(Object source) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, new TypeReference>() { + }); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Converts the given Map to the given class. + * @param the type of the class to return. + * @param source the Map to convert to the given class. + * @param clazz the class to convert the Map to. + * @return the converted class. + */ + public static T mapToClass(Map source, Class clazz) { + try { + String json = OBJECT_MAPPER.writeValueAsString(source); + return OBJECT_MAPPER.readValue(json, clazz); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java new file mode 100644 index 00000000000..6019f3144c9 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java @@ -0,0 +1,28 @@ +package org.springframework.ai.model; + +/** + * Interface representing a request to an AI model. This interface encapsulates the + * necessary information required to interact with an AI model, including instructions or + * inputs (of generic type T) and additional model options. It provides a standardized way + * to send requests to AI models, ensuring that all necessary details are included and can + * be easily managed. + * + * @param the type of instructions or input required by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelRequest { + + /** + * Retrieves the instructions or input required by the AI model. + * @return the instructions or input required by the AI model + */ + T getInstructions(); // required input + + /** + * Retrieves the customizable options for AI model interactions. + * @return the customizable options for AI model interactions + */ + ModelOptions getOptions(); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java new file mode 100644 index 00000000000..32143e93390 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -0,0 +1,38 @@ +package org.springframework.ai.model; + +import java.util.List; + +/** + * Interface representing the response received from an AI model. This interface provides + * methods to access the main result or a list of results generated by the AI model, along + * with the response metadata. It serves as a standardized way to encapsulate and manage + * the output from AI models, ensuring easy retrieval and processing of the generated + * information. + * + * @param the type of the result(s) provided by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelResponse> { + + /** + * Retrieves the result of the AI model. + * @param the type of the result generated by the AI model + * @return the result generated by the AI model + */ + T getResult(); + + /** + * Retrieves the list of generated outputs by the AI model. + * @param the type of the generated outputs + * @return the list of generated outputs + */ + List getResults(); + + /** + * Retrieves the response metadata associated with the AI model's response. + * @return the response metadata + */ + ResponseMetadata getMetadata(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java new file mode 100644 index 00000000000..8fba7ac7a3e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -0,0 +1,27 @@ +package org.springframework.ai.model; + +/** + * This interface provides methods to access the main output of the AI model and the + * metadata associated with this result. It is designed to offer a standardized and + * comprehensive way to handle and interpret the outputs generated by AI models, catering + * to diverse AI applications and use cases. + * + * @param the type of the output generated by the AI model + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ModelResult { + + /** + * Retrieves the output generated by the AI model. + * @return the output generated by the AI model + */ + T getOutput(); + + /** + * Retrieves the metadata associated with the result of an AI model. + * @return the metadata associated with the result + */ + ResultMetadata getResultMetadata(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java new file mode 100644 index 00000000000..b824c4c141d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -0,0 +1,15 @@ +package org.springframework.ai.model; + +/** + * Interface representing metadata associated with an AI model's response. This interface + * is designed to provide additional information about the generative response from an AI + * model, including processing details and model-specific data. It serves as a value + * object within the core domain, enhancing the understanding and management of AI model + * responses in various applications. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ResponseMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java new file mode 100644 index 00000000000..54b6061f6f4 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java @@ -0,0 +1,15 @@ +package org.springframework.ai.model; + +/** + * Interface representing metadata associated with the results of an AI model. This + * interface focuses on providing additional context and insights into the results + * generated by AI models. It could include information like computation time, model + * version, or other relevant details that enhance understanding and management of AI + * model outputs in various applications. + * + * @author Mark Pollack + * @since 0.8.0 + */ +public interface ResultMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java new file mode 100644 index 00000000000..12eaa53b400 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides a set of interfaces and classes for a generic API designed to interact with + * various AI models. This package includes interfaces for handling AI model calls, + * requests, responses, results, and associated metadata. It is designed to offer a + * flexible and adaptable framework for interacting with different types of AI models, + * abstracting the complexities involved in model invocation and result processing. The + * use of generics enhances the API's capability to work with a wide range of models, + * ensuring a broad applicability across diverse AI scenarios. + * + */ +package org.springframework.ai.model; \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java index 2770811506a..8d0cb76b099 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/FormatProvider.java @@ -18,7 +18,7 @@ /** * Implementations of this interface provides instructions for how the output of a - * language model should be formatted. + * language generative should be formatted. * * @author Mark Pollack */ @@ -26,7 +26,7 @@ public interface FormatProvider { /** * @return Returns a string containing instructions for how the output of a language - * model should be formatted. + * generative should be formatted. */ String getFormat(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java index 9387be7444b..40280f19269 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/KeywordMetadataEnricher.java @@ -22,12 +22,12 @@ import org.springframework.ai.chat.ChatClient; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; /** - * Keyword extractor that uses model to extract 'excerpt_keywords' metadata field. + * Keyword extractor that uses generative to extract 'excerpt_keywords' metadata field. * * @author Christian Tzolov */ @@ -65,7 +65,7 @@ public List apply(List documents) { var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent())); - String keywords = this.chatClient.generate(prompt).getGeneration().getContent(); + String keywords = this.chatClient.call(prompt).getResult().getOutput().getContent(); document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); } return documents; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index e5162e1b776..f0a607dbda0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -25,14 +25,14 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentTransformer; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** - * Title extractor with adjacent sharing that uses model to extract 'section_summary', - * 'prev_section_summary', 'next_section_summary' metadata fields. + * Title extractor with adjacent sharing that uses generative to extract + * 'section_summary', 'prev_section_summary', 'next_section_summary' metadata fields. * * @author Christian Tzolov */ @@ -102,7 +102,7 @@ public List apply(List documents) { Prompt prompt = new PromptTemplate(this.summaryTemplate) .create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext)); - documentSummaries.add(this.chatClient.generate(prompt).getGeneration().getContent()); + documentSummaries.add(this.chatClient.call(prompt).getResult().getOutput().getContent()); } for (int i = 0; i < documentSummaries.size(); i++) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index 02c18f6b26a..04f6adaecf2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -17,12 +17,13 @@ package org.springframework.ai.vectorstore.filter; /** - * Portable runtime model for metadata filter expressions. This generic model is used to - * define store agnostic filter expressions than later can be converted into vector-store - * specific, native, expressions. + * Portable runtime generative for metadata filter expressions. This generic generative is + * used to define store agnostic filter expressions than later can be converted into + * vector-store specific, native, expressions. * - * The expression model supports constant comparison {@code (e.g. ==, !=, <, <=, >, >=) }, - * IN/NON-IN checks and AND and OR to compose multiple expressions. + * The expression generative supports constant comparison + * {@code (e.g. ==, !=, <, <=, >, >=) }, IN/NON-IN checks and AND and OR to compose + * multiple expressions. * * For example: * diff --git a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties index 89f3b63cd3a..849a8ceacf4 100644 --- a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties +++ b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties @@ -1,4 +1,4 @@ -# Map of embedding model names and their dimensions +# Map of embedding generative names and their dimensions # OpenAI text-embedding-ada-002=1536 text-similarity-ada-001=1024 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java index 8a796a06b93..1a85939955c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatClientTests.java @@ -21,20 +21,13 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doCallRealMethod; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -import java.util.Collections; +import static org.mockito.Mockito.*; import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.prompt.Prompt; /** * Unit Tests for {@link ChatClient}. @@ -51,10 +44,23 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { String responseMessage = "All your bases are belong to us"; ChatClient mockClient = Mockito.mock(ChatClient.class); - Generation generation = spy(new Generation(responseMessage)); - ChatResponse response = spy(new ChatResponse(Collections.singletonList(generation))); - doCallRealMethod().when(mockClient).generate(anyString()); + AssistantMessage mockAssistantMessage = Mockito.mock(AssistantMessage.class); + when(mockAssistantMessage.getContent()).thenReturn(responseMessage); + + // Create a mock Generation + Generation generation = Mockito.mock(Generation.class); + when(generation.getOutput()).thenReturn(mockAssistantMessage); + + // Create a mock ChatResponse with the mock Generation + ChatResponse response = Mockito.mock(ChatResponse.class); + when(response.getResult()).thenReturn(generation); + + // Generation generation = spy(new Generation(responseMessage)); + // ChatResponse response = spy(new + // ChatResponse(Collections.singletonList(generation))); + + doCallRealMethod().when(mockClient).call(anyString()); doAnswer(invocationOnMock -> { @@ -65,14 +71,15 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() { return response; - }).when(mockClient).generate(any(Prompt.class)); + }).when(mockClient).call(any(Prompt.class)); - assertThat(mockClient.generate(userMessage)).isEqualTo(responseMessage); + assertThat(mockClient.call(userMessage)).isEqualTo(responseMessage); - verify(mockClient, times(1)).generate(eq(userMessage)); - verify(mockClient, times(1)).generate(isA(Prompt.class)); - verify(response, times(1)).getGeneration(); - verify(generation, times(1)).getContent(); + verify(mockClient, times(1)).call(eq(userMessage)); + verify(mockClient, times(1)).call(isA(Prompt.class)); + verify(response, times(1)).getResult(); + verify(generation, times(1)).getOutput(); + verify(mockAssistantMessage, times(1)).getContent(); verifyNoMoreInteractions(mockClient, generation, response); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java index 45d1e68c446..c8f88d0a1e6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java @@ -18,6 +18,7 @@ import java.util.List; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java index 792b093842a..aef94a6b030 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java @@ -22,7 +22,8 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; /** * Unit Tests for {@link PromptMetadata}. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java index 0507d4c7f16..0e44b63d64f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.Usage; /** * Unit Tests for {@link Usage}. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index 74abea55eb2..52dcd7712a6 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -2,6 +2,7 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; @@ -18,17 +19,18 @@ public class PromptTemplateTest { @Test public void testRender() { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a generative for + // testing Map model = new HashMap<>(); model.put("key1", "value1"); model.put("key2", true); model.put("key3", 100); - // Create a simple template with placeholders for keys in the model + // Create a simple template with placeholders for keys in the generative String template = "This is a {key1}, it is {key2}, and it costs {key3}"; PromptTemplate promptTemplate = new PromptTemplate(template, model); - // The expected result after rendering the template with the model + // The expected result after rendering the template with the generative String expected = "This is a value1, it is true, and it costs 100"; String result = promptTemplate.render(); @@ -44,7 +46,8 @@ public void testRender() { @Disabled("Need to improve PromptTemplate to better handle Resource toString and tracking with 'dynamicModel' for underlying StringTemplate") @Test public void testRenderResource() throws Exception { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a generative for + // testing Map model = new HashMap<>(); model.put("key1", "value1"); model.put("key2", true); @@ -55,11 +58,11 @@ public void testRenderResource() throws Exception { model.put("key3", resource); - // Create a simple template with placeholders for keys in the model + // Create a simple template with placeholders for keys in the generative String template = "{key1}, {key2}, {key3}"; PromptTemplate promptTemplate = new PromptTemplate(template, model); - // The expected result after rendering the template with the model + // The expected result after rendering the template with the generative String expected = "value1, true, it costs 100"; String result = promptTemplate.render(); @@ -69,11 +72,12 @@ public void testRenderResource() throws Exception { @Test public void testRenderFailure() { - // Create a map with string keys and object values to serve as a model for testing + // Create a map with string keys and object values to serve as a generative for + // testing Map model = new HashMap<>(); model.put("key1", "value1"); - // Create a simple template that includes a key not present in the model + // Create a simple template that includes a key not present in the generative String template = "This is a {key2}!"; PromptTemplate promptTemplate = new PromptTemplate(template, model); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java index 22d46555f1a..b08ed9aec70 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -18,6 +18,9 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import java.util.HashMap; import java.util.Map; @@ -54,7 +57,7 @@ void newApiPlaygroundTests() { // to have access to Messages Prompt prompt = pt.create(model); assertThat(prompt.getContents()).isNotNull(); - assertThat(prompt.getMessages()).isNotEmpty().hasSize(1); + assertThat(prompt.getInstructions()).isNotEmpty().hasSize(1); System.out.println(prompt.getContents()); String systemTemplate = "You are a helpful assistant that translates {input_language} to {output_language}."; @@ -86,7 +89,7 @@ void newApiPlaygroundTests() { // ChatPromptTemplate chatPromptTemplate = new ChatPromptTemplate(systemPrompt, // humanPrompt); - // Prompt chatPrompt chatPromptTemplate.create(model); + // Prompt chatPrompt chatPromptTemplate.create(generative); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc index 722fb3b7f7d..d8f87898c67 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/onnx.adoc @@ -22,7 +22,7 @@ python3 -m venv venv source ./venv/bin/activate (venv) pip install --upgrade pip (venv) pip install optimum onnx onnxruntime -(venv) optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder +(venv) optimum-cli export onnx --generative sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder ---- The snippet exports the https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2[sentence-transformers/all-MiniLM-L6-v2] transformer into the `onnx-output-folder` folder. Later includes the `tokenizer.json` and `model.onnx` files used by the embedding client. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 97d136c5d57..c188fa663c5 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -156,6 +156,14 @@ true + + + org.springframework.ai + spring-ai-stability-ai + ${project.parent.version} + true + + org.springframework.ai diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java index 57d3dac2802..fadcafe1b79 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiChatProperties.java @@ -34,11 +34,11 @@ public class AzureOpenAiChatProperties { /** * An alternative to sampling with temperature called nucleus sampling. This value - * causes the model to consider the results of tokens with the provided probability - * mass. As an example, a value of 0.15 will cause only the tokens comprising the top - * 15% of probability mass to be considered. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. + * causes the generative to consider the results of tokens with the provided + * probability mass. As an example, a value of 0.15 will cause only the tokens + * comprising the top 15% of probability mass to be considered. It is not recommended + * to modify temperature and top_p for the same completions request as the interaction + * of these two settings is difficult to predict. */ private Double topP; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java index 7ba71f1a96c..c7ddee8537c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiEmbeddingProperties.java @@ -24,7 +24,7 @@ public class AzureOpenAiEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.azure.openai.embedding"; /** - * The text embedding model to use for the embedding client. + * The text embedding generative to use for the embedding client. */ private String model = "text-embedding-ada-002"; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java index 5b22fff7b60..0b3d36f9857 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java @@ -39,23 +39,24 @@ public class BedrockAnthropicChatProperties { private boolean enabled = false; /** - * The model id to use. See the {@link AnthropicChatModel} for the supported models. + * The generative id to use. See the {@link AnthropicChatModel} for the supported + * models. */ private String model = AnthropicChatModel.CLAUDE_V2.id(); /** * Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. - * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * generative. This value specifies default to be used by the backend while making the + * call to the generative. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; @@ -68,19 +69,19 @@ public class BedrockAnthropicChatProperties { private Integer maxTokensToSample = 300; /** - * Specify the number of token choices the model uses to generate the next token. + * Specify the number of token choices the generative uses to generate the next token. */ private Integer topK = 10; /** - * Configure up to four sequences that the model recognizes. After a stop sequence, - * the model stops generating further tokens. The returned text doesn't contain the - * stop sequence. + * Configure up to four sequences that the generative recognizes. After a stop + * sequence, the generative stops generating further tokens. The returned text doesn't + * contain the stop sequence. */ private List stopSequences = List.of("\n\nHuman:"); /** - * The version of the model to use. The default value is bedrock-2023-05-31. + * The version of the generative to use. The default value is bedrock-2023-05-31. */ private String anthropicVersion = AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java index 97d58bbf3de..0985d55b109 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java @@ -40,7 +40,7 @@ public class BedrockCohereChatProperties { private boolean enabled = false; /** - * Bedrock Cohere Chat model name. Defaults to 'cohere-command-v14'. + * Bedrock Cohere Chat generative name. Defaults to 'cohere-command-v14'. */ private String model = CohereChatBedrockApi.CohereChatModel.COHERE_COMMAND_V14.id(); @@ -52,14 +52,14 @@ public class BedrockCohereChatProperties { /** * (optional) The maximum cumulative probability of tokens to consider when sampling. - * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the - * smallest set of tokens whose probability sum is at least topP. + * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers + * the smallest set of tokens whose probability sum is at least topP. */ private Float topP; /** - * (optional) Specify the number of token choices the model uses to generate the next - * token. + * (optional) Specify the number of token choices the generative uses to generate the + * next token. */ private Integer topK; @@ -69,9 +69,9 @@ public class BedrockCohereChatProperties { private Integer maxTokens; /** - * (optional) Configure up to four sequences that the model recognizes. After a stop - * sequence, the model stops generating further tokens. The returned text doesn't - * contain the stop sequence. + * (optional) Configure up to four sequences that the generative recognizes. After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. */ private List stopSequences; @@ -81,19 +81,19 @@ public class BedrockCohereChatProperties { private ReturnLikelihoods returnLikelihoods; /** - * (optional) The maximum number of generations that the model should return. + * (optional) The maximum number of generations that the generative should return. */ private Integer numGenerations; /** - * LogitBias prevents the model from generating unwanted tokens or incentivize the - * model to include desired tokens. The token likelihoods. + * LogitBias prevents the generative from generating unwanted tokens or incentivize + * the generative to include desired tokens. The token likelihoods. */ private String logitBiasToken; /** - * LogitBias prevents the model from generating unwanted tokens or incentivize the - * model to include desired tokens. A float between -10 and 10. + * LogitBias prevents the generative from generating unwanted tokens or incentivize + * the generative to include desired tokens. A float between -10 and 10. */ private Float logitBiasBias; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java index 7197d99ed8b..fc1c56996a3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java @@ -38,7 +38,8 @@ public class BedrockCohereEmbeddingProperties { private boolean enabled = false; /** - * Bedrock Cohere Embedding model name. Defaults to 'cohere.embed-multilingual-v3'. + * Bedrock Cohere Embedding generative name. Defaults to + * 'cohere.embed-multilingual-v3'. */ private String model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java index 4c7b05ebef6..9d2c3d7e28f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatProperties.java @@ -38,27 +38,27 @@ public class BedrockLlama2ChatProperties { /** * Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. - * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * generative. This value specifies default to be used by the backend while making the + * call to the generative. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; /** - * Specify the maximum number of tokens to use in the generated response. The model - * truncates the response once the generated text exceeds maxGenLen. + * Specify the maximum number of tokens to use in the generated response. The + * generative truncates the response once the generated text exceeds maxGenLen. */ private Integer maxGenLen = 300; /** - * The model id to use. See the {@link Llama2ChatModel} for the supported models. + * The generative id to use. See the {@link Llama2ChatModel} for the supported models. */ private String model = Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java index f72adf15d47..60d2ce71291 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java @@ -38,7 +38,7 @@ public class BedrockTitanChatProperties { private boolean enabled = false; /** - * Bedrock Titan Chat model name. Defaults to 'amazon.titan-text-express-v1'. + * Bedrock Titan Chat generative name. Defaults to 'amazon.titan-text-express-v1'. */ private String model = TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(); @@ -50,8 +50,8 @@ public class BedrockTitanChatProperties { /** * (optional) The maximum cumulative probability of tokens to consider when sampling. - * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the - * smallest set of tokens whose probability sum is at least topP. + * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers + * the smallest set of tokens whose probability sum is at least topP. */ private Float topP; @@ -61,9 +61,9 @@ public class BedrockTitanChatProperties { private Integer maxTokenCount; /** - * (optional) Configure up to four sequences that the model recognizes. After a stop - * sequence, the model stops generating further tokens. The returned text doesn't - * contain the stop sequence. + * (optional) Configure up to four sequences that the generative recognizes. After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. */ private List stopSequences; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java index f74026c813c..1a6b1e5c141 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingProperties.java @@ -37,7 +37,7 @@ public class BedrockTitanEmbeddingProperties { private boolean enabled = false; /** - * Bedrock Titan Embedding model name. Defaults to 'amazon.titan-embed-image-v1'. + * Bedrock Titan Embedding generative name. Defaults to 'amazon.titan-embed-image-v1'. */ private String model = TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java index ed8762fc3de..1b07aca6f7b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaChatProperties.java @@ -31,14 +31,14 @@ public class OllamaChatProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.chat"; /** - * Ollama Chat model name. Defaults to 'llama2'. + * Ollama Chat generative name. Defaults to 'llama2'. */ private String model = "llama2"; /** - * Client lever Ollama options. Use this property to configure model temperature, topK - * and topP and alike parameters. The null values are ignored defaulting to the - * model's defaults. + * Client lever Ollama options. Use this property to configure generative temperature, + * topK and topP and alike parameters. The null values are ignored defaulting to the + * generative's defaults. */ private OllamaOptions options = new OllamaOptions(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java index a4dd65a7ac1..f4039bf57f5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingProperties.java @@ -31,14 +31,14 @@ public class OllamaEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding"; /** - * Ollama Embedding model name. Defaults to 'llama2'. + * Ollama Embedding generative name. Defaults to 'llama2'. */ private String model = "llama2"; /** - * Client lever Ollama options. Use this property to configure model temperature, topK - * and topP and alike parameters. The null values are ignored defaulting to the - * model's defaults. + * Client lever Ollama options. Use this property to configure generative temperature, + * topK and topP and alike parameters. The null values are ignored defaulting to the + * generative's defaults. */ private OllamaOptions options = new OllamaOptions(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java new file mode 100644 index 00000000000..485c2eb841f --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImageAutoConfiguration.java @@ -0,0 +1,33 @@ +package org.springframework.ai.autoconfigure.stabilityai; + +import org.springframework.ai.autoconfigure.NativeHints; +import org.springframework.ai.stabilityai.StabilityAiImageClient; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ImportRuntimeHints; + +@AutoConfiguration +@ConditionalOnClass(StabilityAiApi.class) +@EnableConfigurationProperties({ StabilityAiProperties.class }) +@ImportRuntimeHints(NativeHints.class) +public class StabilityAiImageAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public StabilityAiApi stabilityAiApi(StabilityAiProperties stabilityAiProperties) { + return new StabilityAiApi(stabilityAiProperties.getApiKey(), stabilityAiProperties.getBaseUrl(), + stabilityAiProperties.getOptions().getModel()); + } + + @Bean + @ConditionalOnMissingBean + public StabilityAiImageClient stabilityAiImageClient(StabilityAiApi stabilityAiApi, + StabilityAiProperties stabilityAiProperties) { + return new StabilityAiImageClient(stabilityAiApi, stabilityAiProperties.getOptions()); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java new file mode 100644 index 00000000000..5955999dbf2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiProperties.java @@ -0,0 +1,44 @@ +package org.springframework.ai.autoconfigure.stabilityai; + +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(StabilityAiProperties.CONFIG_PREFIX) +public class StabilityAiProperties { + + public static final String CONFIG_PREFIX = "spring.ai.stabilityai"; + + private String apiKey; + + private String baseUrl = StabilityAiApi.DEFAULT_BASE_URL; + + @NestedConfigurationProperty + private StabilityAiImageOptions options; + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + + public StabilityAiImageOptions getOptions() { + return options; + } + + public void setOptions(StabilityAiImageOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java index 8bc236270e7..6b1fa816df1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientProperties.java @@ -39,7 +39,7 @@ public class TransformersEmbeddingClientProperties { public static final String CONFIG_PREFIX = "spring.ai.embedding.transformer"; public static final String DEFAULT_CACHE_DIRECTORY = new File(System.getProperty("java.io.tmpdir"), - "spring-ai-onnx-model") + "spring-ai-onnx-generative") .getAbsolutePath(); /** @@ -91,7 +91,7 @@ public static class Cache { /** * Resource cache directory. Used to cache remote resources, such as the ONNX * models, to the local file system. Applicable only for cache.enabled == true. - * Defaults to {java.io.tmpdir}/spring-ai-onnx-model. + * Defaults to {java.io.tmpdir}/spring-ai-onnx-generative. */ private String directory = DEFAULT_CACHE_DIRECTORY; @@ -125,7 +125,7 @@ public Cache getCache() { public static class Onnx { /** - * Existing, pre-trained ONNX model. Commonly exported from + * Existing, pre-trained ONNX generative. Commonly exported from * https://sbert.net/docs/pretrained_models.html. Defaults to * sentence-transformers/all-MiniLM-L6-v2. */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java index a53f6567098..67b7ba714c3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiChatProperties.java @@ -27,16 +27,16 @@ public class VertexAiChatProperties { /** * Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. * A value closer to 1.0 will produce responses that are more varied, while a value - * closer to 0.0 will typically result in less surprising responses from the model. - * This value specifies default to be used by the backend while making the call to the - * model. + * closer to 0.0 will typically result in less surprising responses from the + * generative. This value specifies default to be used by the backend while making the + * call to the generative. */ private Float temperature = 0.7f; /** - * The maximum cumulative probability of tokens to consider when sampling. The model - * uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest - * set of tokens whose probability sum is at least topP. + * The maximum cumulative probability of tokens to consider when sampling. The + * generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the + * smallest set of tokens whose probability sum is at least topP. */ private Float topP = null; @@ -47,14 +47,14 @@ public class VertexAiChatProperties { private Integer candidateCount = 1; /** - * The maximum number of tokens to consider when sampling. The model uses combined - * Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable - * tokens. + * The maximum number of tokens to consider when sampling. The generative uses + * combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most + * probable tokens. */ private Integer topK = 20; /** - * Vertex AI PaLM API model name. Defaults to chat-bison-001 + * Vertex AI PaLM API generative name. Defaults to chat-bison-001 */ private String model = VertexAiApi.DEFAULT_GENERATE_MODEL; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java index dcdb8d140f4..433e9039248 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/VertexAiEmbeddingProperties.java @@ -25,7 +25,7 @@ public class VertexAiEmbeddingProperties { public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.embedding"; /** - * Vertex AI PaLM API embedding model name. Defaults to embedding-gecko-001. + * Vertex AI PaLM API embedding generative name. Defaults to embedding-gecko-001. */ private String model = VertexAiApi.DEFAULT_EMBEDDING_MODEL; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 3786b5377fc..240fdc4db64 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1,5 +1,6 @@ org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration +org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingClientAutoConfiguration org.springframework.ai.autoconfigure.huggingface.HuggingfaceChatAutoConfiguration org.springframework.ai.autoconfigure.vertexai.VertexAiAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 9caf7368bba..4fe2cf7c9c3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; @@ -30,10 +31,10 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -56,11 +57,11 @@ public class AzureOpenAiAutoConfigurationIT { "spring.ai.azure.openai.api-key=" + System.getenv("AZURE_OPENAI_API_KEY"), "spring.ai.azure.openai.endpoint=" + System.getenv("AZURE_OPENAI_ENDPOINT"), - "spring.ai.azure.openai.chat.model=" + CHAT_MODEL_NAME, + "spring.ai.azure.openai.chat.generative=" + CHAT_MODEL_NAME, "spring.ai.azure.openai.chat.temperature=0.8", "spring.ai.azure.openai.chat.maxTokens=123", - "spring.ai.azure.openai.embedding.model=" + EMBEDDING_MODEL_NAME + "spring.ai.azure.openai.embedding.generative=" + EMBEDDING_MODEL_NAME // @formatter:on ).withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)); @@ -78,8 +79,8 @@ public class AzureOpenAiAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { AzureOpenAiChatClient chatClient = context.getBean(AzureOpenAiChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -95,9 +96,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index 5a77074b8e1..f75fff6507a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -40,7 +40,7 @@ public void chatPropertiesTest() { // @formatter:off "spring.ai.azure.openai.api-key=TEST_API_KEY", "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", - "spring.ai.azure.openai.chat.model=MODEL_XYZ", + "spring.ai.azure.openai.chat.generative=MODEL_XYZ", "spring.ai.azure.openai.chat.temperature=0.55", "spring.ai.azure.openai.chat.topP=0.56", "spring.ai.azure.openai.chat.maxTokens=123") @@ -66,7 +66,8 @@ public void embeddingPropertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.azure.openai.api-key=TEST_API_KEY", - "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", "spring.ai.azure.openai.embedding.model=MODEL_XYZ") + "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", + "spring.ai.azure.openai.embedding.generative=MODEL_XYZ") .withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(AzureOpenAiEmbeddingProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 416bb867695..9e525b62f72 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -51,7 +52,7 @@ public class BedrockAnthropicChatAutoConfigurationIT { .withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true", "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), - "spring.ai.bedrock.anthropic.chat.model=" + AnthropicChatModel.CLAUDE_V2.id(), + "spring.ai.bedrock.anthropic.chat.generative=" + AnthropicChatModel.CLAUDE_V2.id(), "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.temperature=0.5", "spring.ai.bedrock.anthropic.chat.maxGenLen=500") .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)); @@ -70,8 +71,8 @@ public class BedrockAnthropicChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockAnthropicChatClient anthropicChatClient = context.getBean(BedrockAnthropicChatClient.class); - ChatResponse response = anthropicChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = anthropicChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -88,9 +89,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); @@ -103,7 +105,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.anthropic.chat.model=MODEL_XYZ", + "spring.ai.bedrock.anthropic.chat.generative=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.temperature=0.55") .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index f80a9b78dd1..b50347f5b05 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -32,10 +33,10 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -54,7 +55,7 @@ public class BedrockCohereChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.cohere.chat.model=" + CohereChatModel.COHERE_COMMAND_V14.id(), + "spring.ai.bedrock.cohere.chat.generative=" + CohereChatModel.COHERE_COMMAND_V14.id(), "spring.ai.bedrock.cohere.chat.temperature=0.5", "spring.ai.bedrock.cohere.chat.maxTokens=500") .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)); @@ -72,8 +73,8 @@ public class BedrockCohereChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockCohereChatClient cohereChatClient = context.getBean(BedrockCohereChatClient.class); - ChatResponse response = cohereChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = cohereChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -90,9 +91,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); @@ -104,7 +106,7 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.cohere.chat.model=MODEL_XYZ", + "spring.ai.bedrock.cohere.chat.generative=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.cohere.chat.temperature=0.55", "spring.ai.bedrock.cohere.chat.topP=0.55", "spring.ai.bedrock.cohere.chat.topK=10", "spring.ai.bedrock.cohere.chat.stopSequences=END1,END2", diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java index 7ecefc6ba74..ab43d8096b0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -52,7 +53,7 @@ public class BedrockLlama2ChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.llama2.chat.model=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), + "spring.ai.bedrock.llama2.chat.generative=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), "spring.ai.bedrock.llama2.chat.temperature=0.5", "spring.ai.bedrock.llama2.chat.maxGenLen=500") .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)); @@ -70,8 +71,8 @@ public class BedrockLlama2ChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class); - ChatResponse response = llama2ChatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = llama2ChatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -88,9 +89,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(2); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); @@ -103,7 +105,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.llama2.chat.model=MODEL_XYZ", + "spring.ai.bedrock.llama2.chat.generative=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.llama2.chat.temperature=0.55", "spring.ai.bedrock.llama2.chat.maxGenLen=123") .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 835217e60be..9d7bde70070 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; @@ -30,10 +31,10 @@ import org.springframework.ai.bedrock.titan.BedrockTitanChatClient; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; import org.springframework.ai.chat.Generation; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -52,7 +53,7 @@ public class BedrockTitanChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.titan.chat.model=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), + "spring.ai.bedrock.titan.chat.generative=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), "spring.ai.bedrock.titan.chat.temperature=0.5", "spring.ai.bedrock.titan.chat.maxTokens=500") .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)); @@ -70,8 +71,8 @@ public class BedrockTitanChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockTitanChatClient chatClient = context.getBean(BedrockTitanChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -87,9 +88,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); @@ -102,7 +104,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.titan.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.titan.chat.model=MODEL_XYZ", + "spring.ai.bedrock.titan.chat.generative=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.titan.chat.temperature=0.55", "spring.ai.bedrock.titan.chat.topP=0.55", "spring.ai.bedrock.titan.chat.stopSequences=END1,END2", diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 59169990532..9ddcb68ee7e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -47,7 +47,7 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.titan.embedding.model=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) + "spring.ai.bedrock.titan.embedding.generative=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)); @Test @@ -84,10 +84,12 @@ public void singleImageEmbedding() { @Test public void propertiesTest() { - new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true", - "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), - "spring.ai.bedrock.titan.embedding.model=MODEL_XYZ", "spring.ai.bedrock.titan.embedding.inputType=TEXT") + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), + "spring.ai.bedrock.titan.embedding.generative=MODEL_XYZ", + "spring.ai.bedrock.titan.embedding.inputType=TEXT") .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockTitanEmbeddingProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java index 2c0a7f1a1ee..45c04eb9e4e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java @@ -23,11 +23,12 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.testcontainers.containers.GenericContainer; @@ -61,7 +62,7 @@ public class OllamaAutoConfigurationIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL_NAME + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); logger.info(MODEL_NAME + " pulling competed!"); @@ -69,7 +70,7 @@ public static void beforeAll() throws IOException, InterruptedException { } private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.model=" + MODEL_NAME, + .withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.generative=" + MODEL_NAME, "spring.ai.ollama.baseUrl=" + baseUrl, "spring.ai.ollama.chat.temperature=0.5", "spring.ai.ollama.chat.topK=10") .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); @@ -88,8 +89,8 @@ public static void beforeAll() throws IOException, InterruptedException { public void chatCompletion() { contextRunner.run(context -> { OllamaChatClient chatClient = context.getBean(OllamaChatClient.class); - ChatResponse response = chatClient.generate(new Prompt(List.of(userMessage, systemMessage))); - assertThat(response.getGeneration().getContent()).contains("Blackbeard"); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -105,9 +106,10 @@ public void chatCompletionStreaming() { assertThat(responses.size()).isGreaterThan(1); String stitchedResponseContent = responses.stream() - .map(ChatResponse::getGenerations) + .map(ChatResponse::getResults) .flatMap(List::stream) - .map(Generation::getContent) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) .collect(Collectors.joining()); assertThat(stitchedResponseContent).contains("Blackbeard"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java index 9ad9be0096e..37fe5f9ece1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java @@ -35,7 +35,7 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.base-url=TEST_BASE_URL", - "spring.ai.ollama.chat.model=MODEL_XYZ", + "spring.ai.ollama.chat.generative=MODEL_XYZ", "spring.ai.ollama.chat.options.temperature=0.55", "spring.ai.ollama.chat.options.topP=0.56", "spring.ai.ollama.chat.options.topK=123") diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 51a4a60d9dc..11f49d4144b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -53,7 +53,7 @@ public class OllamaEmbeddingAutoConfigurationIT { @BeforeAll public static void beforeAll() throws IOException, InterruptedException { - logger.info("Start pulling the '" + MODEL_NAME + " ' model ... would take several minutes ..."); + logger.info("Start pulling the '" + MODEL_NAME + " ' generative ... would take several minutes ..."); ollamaContainer.execInContainer("ollama", "pull", MODEL_NAME); logger.info(MODEL_NAME + " pulling competed!"); @@ -61,7 +61,8 @@ public static void beforeAll() throws IOException, InterruptedException { } private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.embedding.model=" + MODEL_NAME, "spring.ai.ollama.base-url=" + baseUrl) + .withPropertyValues("spring.ai.ollama.embedding.generative=" + MODEL_NAME, + "spring.ai.ollama.base-url=" + baseUrl) .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index e9976f11cf5..45fb26c76f7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -33,7 +33,8 @@ public class OllamaEmbeddingAutoConfigurationTests { public void propertiesTest() { new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.base-url=TEST_BASE_URL", "spring.ai.ollama.embedding.model=MODEL_XYZ", + .withPropertyValues("spring.ai.ollama.base-url=TEST_BASE_URL", + "spring.ai.ollama.embedding.generative=MODEL_XYZ", "spring.ai.ollama.embedding.options.temperature=0.13", "spring.ai.ollama.embedding.options.topK=13") .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 17e8a0c5c71..2caa91ca254 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -49,7 +49,7 @@ public class OpenAiAutoConfigurationIT { void generate() { contextRunner.run(context -> { OpenAiChatClient client = context.getBean(OpenAiChatClient.class); - String response = client.generate("Hello"); + String response = client.call("Hello"); assertThat(response).isNotEmpty(); logger.info("Response: " + response); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java index 67121339cf4..8b497830a40 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingClientAutoConfigurationIT.java @@ -45,8 +45,8 @@ public void embedding() { contextRunner.run(context -> { var properties = context.getBean(TransformersEmbeddingClientProperties.class); assertThat(properties.getCache().isEnabled()).isTrue(); - assertThat(properties.getCache().getDirectory()) - .isEqualTo(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath()); + assertThat(properties.getCache().getDirectory()).isEqualTo( + new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-generative").getAbsolutePath()); EmbeddingClient embeddingClient = context.getBean(EmbeddingClient.class); assertThat(embeddingClient).isInstanceOf(TransformersEmbeddingClient.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java index e650a68672b..44036f9743f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java @@ -39,8 +39,8 @@ public class VertexAiAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.vertex.ai.baseUrl=https://generativelanguage.googleapis.com/v1beta3", "spring.ai.vertex.ai.apiKey=" + System.getenv("PALM_API_KEY"), - "spring.ai.vertex.ai.chat.model=chat-bison-001", - "spring.ai.vertex.ai.embedding.model=embedding-gecko-001") + "spring.ai.vertex.ai.chat.generative=chat-bison-001", + "spring.ai.vertex.ai.embedding.generative=embedding-gecko-001") .withConfiguration(AutoConfigurations.of(VertexAiAutoConfiguration.class)); @Test @@ -48,7 +48,7 @@ void generate() { contextRunner.run(context -> { VertexAiChatClient client = context.getBean(VertexAiChatClient.class); - String response = client.generate("Hello"); + String response = client.call("Hello"); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java index b8b8ca23c6b..f7bf74db4b0 100644 --- a/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java +++ b/spring-ai-test/src/main/java/org/springframework/ai/evaluation/BasicEvaluationTest.java @@ -20,10 +20,10 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.PromptTemplate; -import org.springframework.ai.prompt.messages.Message; -import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -55,7 +55,7 @@ public class BasicEvaluationTest { protected void evaluateQuestionAndAnswer(String question, ChatResponse response, boolean factBased) { assertThat(response).isNotNull(); - String answer = response.getGeneration().getContent(); + String answer = response.getResult().getOutput().getContent(); logger.info("Question: " + question); logger.info("Answer:" + answer); PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, @@ -69,12 +69,12 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response, } Message userMessage = userPromptTemplate.createMessage(); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - String yesOrNo = openAiChatClient.generate(prompt).getGeneration().getContent(); + String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent(); logger.info("Is Answer related to question: " + yesOrNo); if (yesOrNo.equalsIgnoreCase("no")) { SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource); prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage)); - String reasonForFailure = openAiChatClient.generate(prompt).getGeneration().getContent(); + String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent(); fail(reasonForFailure); } else { From 415e6402e29a1e51f8b3236a3058007a2add11d5 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Tue, 23 Jan 2024 22:31:14 -0500 Subject: [PATCH 02/11] fix compile errors --- .../ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 2caa91ca254..748bc822a5d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -23,14 +23,14 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; import reactor.core.publisher.Flux; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -60,9 +60,8 @@ void generateStreaming() { contextRunner.run(context -> { OpenAiChatClient client = context.getBean(OpenAiChatClient.class); Flux responseFlux = client.generateStream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getGenerations().get(0).getContent(); + return chatResponse.getResults().get(0).getOutput().getContent(); }).collect(Collectors.joining()); assertThat(response).isNotEmpty(); From 7087d2e5764b7a089aa6de7421bc31284e0d0173 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Tue, 23 Jan 2024 22:33:15 -0500 Subject: [PATCH 03/11] In ModelResult, change method from getResultMetadata to getMetadata for better readability --- .../openai/metadata/AzureOpenAiChatClientMetadataTests.java | 2 +- .../chat/OpenAiChatClientWithChatResponseMetadataTests.java | 2 +- .../src/main/java/org/springframework/ai/chat/Generation.java | 2 +- .../main/java/org/springframework/ai/image/ImageGeneration.java | 2 +- .../src/main/java/org/springframework/ai/model/ModelResult.java | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java index 1cf366efa07..8974540dcd9 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatClientMetadataTests.java @@ -124,7 +124,7 @@ private void assertGenerationMetadata(ChatResponse response) { private void assertChoiceMetadata(Generation generation) { - ChatGenerationMetadata chatGenerationMetadata = generation.getResultMetadata(); + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop"); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java index 8c2b494d9c1..4d34cfbd3bc 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientWithChatResponseMetadataTests.java @@ -111,7 +111,7 @@ void aiResponseContainsAiMetadata() { assertThat(promptMetadata).isEmpty(); response.getResults().forEach(generation -> { - ChatGenerationMetadata chatGenerationMetadata = generation.getResultMetadata(); + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP"); assertThat(chatGenerationMetadata.getContentFilterMetadata()).isNull(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java index 5df07190ff1..6234a0d562b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/Generation.java @@ -46,7 +46,7 @@ public AssistantMessage getOutput() { return this.assistantMessage; } - public ChatGenerationMetadata getResultMetadata() { + public ChatGenerationMetadata getMetadata() { ChatGenerationMetadata chatGenerationMetadata = this.chatGenerationMetadata; return chatGenerationMetadata != null ? chatGenerationMetadata : ChatGenerationMetadata.NULL; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java index 98cb0654fbf..3ec55f62271 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -23,7 +23,7 @@ public Image getOutput() { } @Override - public ImageGenerationMetadata getResultMetadata() { + public ImageGenerationMetadata getMetadata() { return imageGenerationMetadata; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java index 8fba7ac7a3e..3e14c80617e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -22,6 +22,6 @@ public interface ModelResult { * Retrieves the metadata associated with the result of an AI model. * @return the metadata associated with the result */ - ResultMetadata getResultMetadata(); + ResultMetadata getMetadata(); } From e2ed8f5797aa3972fd08828d78b2fb8b5bac905b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 24 Jan 2024 15:41:13 +0100 Subject: [PATCH 04/11] Add missing code license blocks --- .../ai/openai/OpenAiEmbeddingClient.java | 1 + .../ai/openai/OpenAiImageClient.java | 16 ++++++++++++++ .../OpenAiImageGenerationMetadata.java | 16 ++++++++++++++ .../metadata/OpenAiImageResponseMetadata.java | 16 ++++++++++++++ ...eClientWithImageResponseMetadataTests.java | 2 +- .../springframework/ai/chat/ChatOptions.java | 16 ++++++++++++++ .../ai/chat/ChatOptionsBuilder.java | 16 ++++++++++++++ .../org/springframework/ai/image/Image.java | 16 ++++++++++++++ .../springframework/ai/image/ImageClient.java | 16 ++++++++++++++ .../ai/image/ImageGeneration.java | 16 ++++++++++++++ .../ai/image/ImageGenerationMetadata.java | 16 ++++++++++++++ .../ai/image/ImageMessage.java | 16 ++++++++++++++ .../ai/image/ImageOptions.java | 16 ++++++++++++++ .../ai/image/ImageOptionsBuilder.java | 16 ++++++++++++++ .../springframework/ai/image/ImagePrompt.java | 16 ++++++++++++++ .../ai/image/ImageResponse.java | 21 ++++++++++++++++--- .../ai/image/ImageResponseMetadata.java | 16 ++++++++++++++ .../ai/image/NewImageClient.java | 18 ++++++++++++++-- .../springframework/ai/model/ModelClient.java | 18 ++++++++++++++-- .../ai/model/ModelOptions.java | 16 ++++++++++++++ .../ai/model/ModelOptionsUtils.java | 16 ++++++++++++++ .../ai/model/ModelRequest.java | 16 ++++++++++++++ .../ai/model/ModelResponse.java | 18 ++++++++++++++-- .../springframework/ai/model/ModelResult.java | 16 ++++++++++++++ .../ai/model/ResponseMetadata.java | 16 ++++++++++++++ .../ai/model/ResultMetadata.java | 16 ++++++++++++++ 26 files changed, 388 insertions(+), 10 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 5524472b1a7..627c476902f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.openai; import java.time.Duration; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java index db829b18369..88b4e42056e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai; import org.slf4j.Logger; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java index 6929d85c3d0..8e9a5d8b0ac 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageGenerationMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.metadata; import org.springframework.ai.image.ImageGenerationMetadata; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java index 996c50e2f5b..b954affa5f6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiImageResponseMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.openai.metadata; import org.springframework.ai.image.ImageResponseMetadata; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java index fa6aa0a3762..5f660a1f8b0 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientWithImageResponseMetadataTests.java @@ -63,7 +63,7 @@ void resetMockServer() { } @Test - void aiResponseContainsImageRespoinseMetadata() { + void aiResponseContainsImageResponseMetadata() { prepareMock(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java index 84d9f2465d2..f6cd9053159 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java index 33866b3bf02..60c72987215 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat; public class ChatOptionsBuilder { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java index 561d7398a37..fa1f0b8ff9d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/Image.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import java.util.Objects; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java index c12db623ef0..bf06964e1ff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageClient.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelClient; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java index 3ec55f62271..94f739266a3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGeneration.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelResult; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java index 7be56cfc959..e140aa814be 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageGenerationMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ResultMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java index 153b606df5c..51d378b8c32 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageMessage.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import java.util.Objects; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java index 5d8ededb006..dbfec79c9d6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelOptions; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index f5e56d96b9c..49dc3497d3e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; public class ImageOptionsBuilder { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java index 99589894c0a..5ea58bea469 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelRequest; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java index 84c4e93ccdd..ad6cda7c9e7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponse.java @@ -1,11 +1,26 @@ -package org.springframework.ai.image; +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import org.springframework.ai.model.ModelResponse; -import org.springframework.ai.model.ResponseMetadata; +package org.springframework.ai.image; import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; + public class ImageResponse implements ModelResponse { private final ImageResponseMetadata imageResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java index 512303ca490..7378fedca6e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageResponseMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ResponseMetadata; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java index 579e8e67376..d5bce3dd333 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java @@ -1,9 +1,23 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.image; import org.springframework.ai.model.ModelClient; -import java.util.List; - @FunctionalInterface public interface NewImageClient extends ModelClient { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java index 2d3d2f3bf95..38de2f149b7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelClient.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** @@ -17,8 +33,6 @@ public interface ModelClient, TRes extends ModelRes /** * Executes a method call to the AI model. * @param request the request object to be sent to the AI model - * @param the generic type of the request object - * @param the generic type of the response from the AI model * @return the response from the AI model */ TRes call(TReq request); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java index ae01e7a770b..9b6e908f4da 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptions.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index b84aa32cb2a..5b0e8ce7c7a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; import com.fasterxml.jackson.core.JsonProcessingException; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java index 6019f3144c9..0aac6da82c7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelRequest.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java index 32143e93390..5c8a17b5827 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResponse.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; import java.util.List; @@ -17,14 +33,12 @@ public interface ModelResponse> { /** * Retrieves the result of the AI model. - * @param the type of the result generated by the AI model * @return the result generated by the AI model */ T getResult(); /** * Retrieves the list of generated outputs by the AI model. - * @param the type of the generated outputs * @return the list of generated outputs */ List getResults(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java index 3e14c80617e..5a5613a7280 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelResult.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java index b824c4c141d..14af864bb49 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResponseMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java index 54b6061f6f4..78d5f7f6a91 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ResultMetadata.java @@ -1,3 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.model; /** From 50ead1931025888711225802207bb7ec83f6587c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 24 Jan 2024 15:43:33 +0100 Subject: [PATCH 05/11] Add missing code license blocks2 --- .../ai/openai/image/OpenAiImageClientIT.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java index 389f017be99..3f628439c83 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.ai.openai.image; import org.junit.jupiter.api.Test; From 00521e0b00235724c6de289abf9f586851cdfeb9 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Wed, 24 Jan 2024 10:34:16 -0500 Subject: [PATCH 06/11] Fix failing tests and add null check to objectToMap --- .../ai/openai/OpenAiImageClient.java | 10 ++++++-- .../ai/openai/image/OpenAiImageClientIT.java | 24 ++++++++++++++++--- .../ai/model/ModelOptionsUtils.java | 4 ++++ 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java index 88b4e42056e..87b6a2fcde3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -70,9 +70,15 @@ public ImageResponse call(ImagePrompt imagePrompt) { // data // types to the data types used in OpenAiImageApi String instructions = imagePrompt.getInstructions().get(0).getText(); + String size; + if (imageOptionsToUse.getWidth() != null && imageOptionsToUse.getHeight() != null) { + size = imageOptionsToUse.getWidth() + "x" + imageOptionsToUse.getHeight(); + } + else { + size = null; + } OpenAiImageApi.OpenAiImageRequest openAiImageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions, - imageOptionsToUse.getModel(), imageOptionsToUse.getN(), imageOptionsToUse.getQuality(), - imageOptionsToUse.getWidth() + "x" + imageOptionsToUse.getHeight(), + imageOptionsToUse.getModel(), imageOptionsToUse.getN(), imageOptionsToUse.getQuality(), size, imageOptionsToUse.getResponseFormat(), imageOptionsToUse.getStyle(), imageOptionsToUse.getUser()); // Make the request diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java index 3f628439c83..be3f31b2e67 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java @@ -15,14 +15,17 @@ */ package org.springframework.ai.openai.image; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.*; import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.boot.test.context.SpringBootTest; +import static org.assertj.core.api.Assertions.assertThat; + @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiImageClientIT extends AbstractIT { @@ -33,8 +36,23 @@ void imageAsUrlTest() { ImageResponse imageResponse = openaiImageClient.call(imagePrompt); - System.out.println(imageResponse); + assertThat(imageResponse.getResults()).hasSize(1); + + ImageResponseMetadata imageResponseMetadata = imageResponse.getMetadata(); + assertThat(imageResponseMetadata.created()).isPositive(); + + var generation = imageResponse.getResult(); + Image image = generation.getOutput(); + assertThat(image.getUrl()).isNotEmpty(); + assertThat(image.getB64Json()).isNull(); + + var imageGenerationMetadata = generation.getMetadata(); + Assertions.assertThat(imageGenerationMetadata).isInstanceOf(OpenAiImageGenerationMetadata.class); + + OpenAiImageGenerationMetadata openAiImageGenerationMetadata = (OpenAiImageGenerationMetadata) imageGenerationMetadata; + assertThat(openAiImageGenerationMetadata).isNotNull(); + assertThat(openAiImageGenerationMetadata.getRevisedPrompt()).isNull(); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 5b0e8ce7c7a..85c47554ef7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -58,6 +59,9 @@ public static T merge(Object source, Object target, Class clazz) { * @return the converted Map. */ public static Map objectToMap(Object source) { + if (source == null) { + return new HashMap<>(); + } try { String json = OBJECT_MAPPER.writeValueAsString(source); return OBJECT_MAPPER.readValue(json, new TypeReference>() { From ed9a4ccabd936cf4ae2677216cdc4023ac1064c6 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 24 Jan 2024 17:16:20 +0100 Subject: [PATCH 07/11] Fix OpenAIImageClient 4xx error handling. Remove redundant classes --- .../ai/openai/api/OpenAiImageApi.java | 56 ++++++------------- .../ai/openai/image/OpenAiImageClientIT.java | 4 +- .../ai/image/NewImageClient.java | 26 --------- 3 files changed, 19 insertions(+), 67 deletions(-) delete mode 100644 spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index cf02c33a8ef..5a2800087b4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -1,10 +1,18 @@ package org.springframework.ai.openai.api; +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.function.Consumer; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.image.ImageResponse; + +import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiClientErrorException; +import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; +import org.springframework.ai.openai.api.OpenAiApi.ResponseError; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -12,12 +20,6 @@ import org.springframework.util.Assert; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import org.springframework.web.reactive.function.client.WebClient; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; -import java.util.function.Consumer; public class OpenAiImageApi { @@ -28,8 +30,6 @@ public class OpenAiImageApi { // Assuming RestClient and WebClient are properly defined somewhere private final RestClient restClient; - private final WebClient webClient; - private final ObjectMapper objectMapper; /** @@ -59,8 +59,13 @@ public boolean hasError(ClientHttpResponse response) throws IOException { @Override public void handleError(ClientHttpResponse response) throws IOException { if (response.getStatusCode().isError()) { - throw new OpenAiApi.OpenAiApiException(String.format("%s - %s", response.getStatusCode().value(), - new ObjectMapper().readValue(response.getBody(), OpenAiApi.ResponseError.class))); + if (response.getStatusCode().is4xxClientError()) { + throw new OpenAiApiClientErrorException(String.format("%s - %s", + response.getStatusCode().value(), + OpenAiImageApi.this.objectMapper.readValue(response.getBody(), ResponseError.class))); + } + throw new OpenAiApiException(String.format("%s - %s", response.getStatusCode().value(), + OpenAiImageApi.this.objectMapper.readValue(response.getBody(), ResponseError.class))); } } }; @@ -69,35 +74,6 @@ public void handleError(ClientHttpResponse response) throws IOException { .defaultHeaders(jsonContentHeaders) .defaultStatusHandler(responseErrorHandler) .build(); - - this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); - } - - public static class OpenAiImageApiException extends RuntimeException { - - public OpenAiImageApiException(String message) { - super(message); - } - - public OpenAiImageApiException(String message, Throwable cause) { - super(message, cause); - } - - } - - @JsonInclude(JsonInclude.Include.NON_NULL) - public static class ImageResponseError { - - private final Error error; - - public ImageResponseError(@JsonProperty("error") Error error) { - this.error = error; - } - - public Error getError() { - return error; - } - } @JsonInclude(JsonInclude.Include.NON_NULL) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java index be3f31b2e67..bcc29dc3d66 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java @@ -32,7 +32,9 @@ public class OpenAiImageClientIT extends AbstractIT { @Test void imageAsUrlTest() { - ImagePrompt imagePrompt = new ImagePrompt("Create an image of a mini golden doodle dog."); + var options = ImageOptionsBuilder.builder().withHeight(256).withWidth(256).build(); + + ImagePrompt imagePrompt = new ImagePrompt("Create an image of a mini golden doodle dog.", options); ImageResponse imageResponse = openaiImageClient.call(imagePrompt); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java b/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java deleted file mode 100644 index d5bce3dd333..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/NewImageClient.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.image; - -import org.springframework.ai.model.ModelClient; - -@FunctionalInterface -public interface NewImageClient extends ModelClient { - - ImageResponse call(ImagePrompt request); - -} From 3199dc76e9fa1eab08888eccd4a0c3ee3bb2bdd0 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Wed, 24 Jan 2024 12:58:30 -0500 Subject: [PATCH 08/11] fix AzureOpenAiAutoConfigurationPropertyTests --- .../azure/AzureOpenAiAutoConfigurationPropertyTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index f75fff6507a..e3f778b0af7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -40,7 +40,7 @@ public void chatPropertiesTest() { // @formatter:off "spring.ai.azure.openai.api-key=TEST_API_KEY", "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", - "spring.ai.azure.openai.chat.generative=MODEL_XYZ", + "spring.ai.azure.openai.chat.model=MODEL_XYZ", "spring.ai.azure.openai.chat.temperature=0.55", "spring.ai.azure.openai.chat.topP=0.56", "spring.ai.azure.openai.chat.maxTokens=123") From 53897276000b27be052d18d45a7c8d195bbb2158 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Wed, 24 Jan 2024 13:07:41 -0500 Subject: [PATCH 09/11] Fix OllamaEmbeddingAutoConfigurationTests --- .../ollama/OllamaEmbeddingAutoConfigurationTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index 45fb26c76f7..e9976f11cf5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -33,8 +33,7 @@ public class OllamaEmbeddingAutoConfigurationTests { public void propertiesTest() { new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.base-url=TEST_BASE_URL", - "spring.ai.ollama.embedding.generative=MODEL_XYZ", + .withPropertyValues("spring.ai.ollama.base-url=TEST_BASE_URL", "spring.ai.ollama.embedding.model=MODEL_XYZ", "spring.ai.ollama.embedding.options.temperature=0.13", "spring.ai.ollama.embedding.options.topK=13") .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) .run(context -> { From 740d1755d22fd41d349820001bbfe19e6924209d Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Wed, 24 Jan 2024 13:16:36 -0500 Subject: [PATCH 10/11] Fix embeddingPropertiesTest --- .../azure/AzureOpenAiAutoConfigurationPropertyTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java index e3f778b0af7..5a77074b8e1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationPropertyTests.java @@ -66,8 +66,7 @@ public void embeddingPropertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.azure.openai.api-key=TEST_API_KEY", - "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", - "spring.ai.azure.openai.embedding.generative=MODEL_XYZ") + "spring.ai.azure.openai.endpoint=TEST_ENDPOINT", "spring.ai.azure.openai.embedding.model=MODEL_XYZ") .withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(AzureOpenAiEmbeddingProperties.class); From 14f3763f74f5b15cf583fba7af2e47151cb4391e Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 24 Jan 2024 19:21:52 +0100 Subject: [PATCH 11/11] Fix generative to model in property key --- .../azure/AzureOpenAiAutoConfigurationIT.java | 4 ++-- .../BedrockAnthropicChatAutoConfigurationIT.java | 4 ++-- .../cohere/BedrockCohereChatAutoConfigurationIT.java | 4 ++-- .../llama2/BedrockLlama2ChatAutoConfigurationIT.java | 4 ++-- .../titan/BedrockTitanChatAutoConfigurationIT.java | 4 ++-- .../BedrockTitanEmbeddingAutoConfigurationIT.java | 12 +++++------- .../ollama/OllamaAutoConfigurationIT.java | 2 +- .../ollama/OllamaAutoConfigurationTests.java | 4 ++-- .../ollama/OllamaEmbeddingAutoConfigurationIT.java | 3 +-- .../OllamaEmbeddingAutoConfigurationTests.java | 2 +- .../vertexai/VertexAiAutoConfigurationIT.java | 4 ++-- 11 files changed, 22 insertions(+), 25 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 4fe2cf7c9c3..4184358ab7d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -57,11 +57,11 @@ public class AzureOpenAiAutoConfigurationIT { "spring.ai.azure.openai.api-key=" + System.getenv("AZURE_OPENAI_API_KEY"), "spring.ai.azure.openai.endpoint=" + System.getenv("AZURE_OPENAI_ENDPOINT"), - "spring.ai.azure.openai.chat.generative=" + CHAT_MODEL_NAME, + "spring.ai.azure.openai.chat.model=" + CHAT_MODEL_NAME, "spring.ai.azure.openai.chat.temperature=0.8", "spring.ai.azure.openai.chat.maxTokens=123", - "spring.ai.azure.openai.embedding.generative=" + EMBEDDING_MODEL_NAME + "spring.ai.azure.openai.embedding.model=" + EMBEDDING_MODEL_NAME // @formatter:on ).withConfiguration(AutoConfigurations.of(AzureOpenAiAutoConfiguration.class)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 9e525b62f72..8f2a2eaff84 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -52,7 +52,7 @@ public class BedrockAnthropicChatAutoConfigurationIT { .withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true", "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), - "spring.ai.bedrock.anthropic.chat.generative=" + AnthropicChatModel.CLAUDE_V2.id(), + "spring.ai.bedrock.anthropic.chat.model=" + AnthropicChatModel.CLAUDE_V2.id(), "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.temperature=0.5", "spring.ai.bedrock.anthropic.chat.maxGenLen=500") .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)); @@ -105,7 +105,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.anthropic.chat.generative=MODEL_XYZ", + "spring.ai.bedrock.anthropic.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.temperature=0.55") .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index b50347f5b05..b40d337afba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -55,7 +55,7 @@ public class BedrockCohereChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.cohere.chat.generative=" + CohereChatModel.COHERE_COMMAND_V14.id(), + "spring.ai.bedrock.cohere.chat.model=" + CohereChatModel.COHERE_COMMAND_V14.id(), "spring.ai.bedrock.cohere.chat.temperature=0.5", "spring.ai.bedrock.cohere.chat.maxTokens=500") .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)); @@ -106,7 +106,7 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.cohere.chat.generative=MODEL_XYZ", + "spring.ai.bedrock.cohere.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.cohere.chat.temperature=0.55", "spring.ai.bedrock.cohere.chat.topP=0.55", "spring.ai.bedrock.cohere.chat.topK=10", "spring.ai.bedrock.cohere.chat.stopSequences=END1,END2", diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java index ab43d8096b0..97435c61d50 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java @@ -53,7 +53,7 @@ public class BedrockLlama2ChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.llama2.chat.generative=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), + "spring.ai.bedrock.llama2.chat.model=" + Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(), "spring.ai.bedrock.llama2.chat.temperature=0.5", "spring.ai.bedrock.llama2.chat.maxGenLen=500") .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)); @@ -105,7 +105,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.llama2.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.llama2.chat.generative=MODEL_XYZ", + "spring.ai.bedrock.llama2.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.llama2.chat.temperature=0.55", "spring.ai.bedrock.llama2.chat.maxGenLen=123") .withConfiguration(AutoConfigurations.of(BedrockLlama2ChatAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 9d7bde70070..44e7ef97503 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -53,7 +53,7 @@ public class BedrockTitanChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.titan.chat.generative=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), + "spring.ai.bedrock.titan.chat.model=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), "spring.ai.bedrock.titan.chat.temperature=0.5", "spring.ai.bedrock.titan.chat.maxTokens=500") .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)); @@ -104,7 +104,7 @@ public void propertiesTest() { new ApplicationContextRunner() .withPropertyValues("spring.ai.bedrock.titan.chat.enabled=true", "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.titan.chat.generative=MODEL_XYZ", + "spring.ai.bedrock.titan.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.titan.chat.temperature=0.55", "spring.ai.bedrock.titan.chat.topP=0.55", "spring.ai.bedrock.titan.chat.stopSequences=END1,END2", diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 9ddcb68ee7e..59169990532 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -47,7 +47,7 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.titan.embedding.generative=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) + "spring.ai.bedrock.titan.embedding.model=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)); @Test @@ -84,12 +84,10 @@ public void singleImageEmbedding() { @Test public void propertiesTest() { - new ApplicationContextRunner() - .withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true", - "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", - "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), - "spring.ai.bedrock.titan.embedding.generative=MODEL_XYZ", - "spring.ai.bedrock.titan.embedding.inputType=TEXT") + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), + "spring.ai.bedrock.titan.embedding.model=MODEL_XYZ", "spring.ai.bedrock.titan.embedding.inputType=TEXT") .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockTitanEmbeddingProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java index 45c04eb9e4e..0679da52dca 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java @@ -70,7 +70,7 @@ public static void beforeAll() throws IOException, InterruptedException { } private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.generative=" + MODEL_NAME, + .withPropertyValues("spring.ai.ollama.chat.enabled=true", "spring.ai.ollama.chat.model=" + MODEL_NAME, "spring.ai.ollama.baseUrl=" + baseUrl, "spring.ai.ollama.chat.temperature=0.5", "spring.ai.ollama.chat.topK=10") .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java index 37fe5f9ece1..d7465fc9828 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationTests.java @@ -17,7 +17,7 @@ package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; -import org.springframework.ai.ollama.OllamaChatClient; + import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -35,7 +35,7 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.base-url=TEST_BASE_URL", - "spring.ai.ollama.chat.generative=MODEL_XYZ", + "spring.ai.ollama.chat.model=MODEL_XYZ", "spring.ai.ollama.chat.options.temperature=0.55", "spring.ai.ollama.chat.options.topP=0.56", "spring.ai.ollama.chat.options.topK=123") diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 11f49d4144b..6cae39e9700 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -61,8 +61,7 @@ public static void beforeAll() throws IOException, InterruptedException { } private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.ollama.embedding.generative=" + MODEL_NAME, - "spring.ai.ollama.base-url=" + baseUrl) + .withPropertyValues("spring.ai.ollama.embedding.model=" + MODEL_NAME, "spring.ai.ollama.base-url=" + baseUrl) .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index e9976f11cf5..63d8287dc4e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -17,7 +17,7 @@ package org.springframework.ai.autoconfigure.ollama; import org.junit.jupiter.api.Test; -import org.springframework.ai.ollama.OllamaEmbeddingClient; + import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java index 44036f9743f..429e4d385fc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java @@ -39,8 +39,8 @@ public class VertexAiAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.vertex.ai.baseUrl=https://generativelanguage.googleapis.com/v1beta3", "spring.ai.vertex.ai.apiKey=" + System.getenv("PALM_API_KEY"), - "spring.ai.vertex.ai.chat.generative=chat-bison-001", - "spring.ai.vertex.ai.embedding.generative=embedding-gecko-001") + "spring.ai.vertex.ai.chat.model=chat-bison-001", + "spring.ai.vertex.ai.embedding.model=embedding-gecko-001") .withConfiguration(AutoConfigurations.of(VertexAiAutoConfiguration.class)); @Test